diff --git a/ast_checker/engines/__init__.py b/ast_checker/engines/__init__.py new file mode 100644 index 0000000..9b88312 --- /dev/null +++ b/ast_checker/engines/__init__.py @@ -0,0 +1,21 @@ +from .function_call import CountFunctionCallEngine, MustCallFunctionEngine, MustNotCallFunctionEngine +from .method_call import MustCallMethodEngine, MustNotCallMethodEngine +from .node_count import CountNodeEngine +from .node_exists import MustExistNodeEngine, MustNotExistNodeEngine +from .operator import MustUseOperatorEngine + +ENGINES = { + "must_exist_node": MustExistNodeEngine(), + "must_not_exist_node": MustNotExistNodeEngine(), + "count_node": CountNodeEngine(), + "must_call_function": MustCallFunctionEngine(), + "must_not_call_function": MustNotCallFunctionEngine(), + "count_function_call": CountFunctionCallEngine(), + "must_call_method": MustCallMethodEngine(), + "must_not_call_method": MustNotCallMethodEngine(), + "must_use_operator": MustUseOperatorEngine(), +} + + +def get_engine(name: str): + return ENGINES.get(name) diff --git a/ast_checker/engines/base.py b/ast_checker/engines/base.py new file mode 100644 index 0000000..3b28da8 --- /dev/null +++ b/ast_checker/engines/base.py @@ -0,0 +1,18 @@ +class BaseEngine: + @staticmethod + def collect_nodes(node, node_type): + results = [] + if node.type == node_type: + results.append(node) + for child in node.children: + results.extend(BaseEngine.collect_nodes(child, node_type)) + return results + + @staticmethod + def has_node(node, node_type): + if node.type == node_type: + return True + return any(BaseEngine.has_node(child, node_type) for child in node.children) + + def check(self, tree, rule, language, mapping) -> list[str]: + raise NotImplementedError diff --git a/ast_checker/engines/function_call.py b/ast_checker/engines/function_call.py new file mode 100644 index 0000000..d77ce1a --- /dev/null +++ b/ast_checker/engines/function_call.py @@ -0,0 +1,47 @@ +from .base import BaseEngine + +CALL_NODE_TYPES = { + "Python3": "call", + "C": "call_expression", +} + + +class _FunctionCallBase(BaseEngine): + def _find_function_calls(self, root, func_name, language): + call_type = CALL_NODE_TYPES.get(language, "call") + calls = self.collect_nodes(root, call_type) + matches = [] + for call in calls: + func_node = call.child_by_field_name("function") + if func_node and func_node.type == "identifier" and func_node.text.decode() == func_name: + matches.append(call) + return matches + + +class MustCallFunctionEngine(_FunctionCallBase): + def check(self, tree, rule, language, mapping): + target = rule["target"] + if not self._find_function_calls(tree.root_node, target, language): + return [rule.get("message", f"必须调用 {target}()")] + return [] + + +class MustNotCallFunctionEngine(_FunctionCallBase): + def check(self, tree, rule, language, mapping): + target = rule["target"] + if self._find_function_calls(tree.root_node, target, language): + return [rule.get("message", f"不能调用 {target}()")] + return [] + + +class CountFunctionCallEngine(_FunctionCallBase): + def check(self, tree, rule, language, mapping): + target = rule["target"] + count = len(self._find_function_calls(tree.root_node, target, language)) + min_count = rule.get("min") + max_count = rule.get("max") + if min_count is not None and count < min_count: + return [rule.get("message", f"{target}() 至少调用 {min_count} 次,当前 {count} 次")] + if max_count is not None and count > max_count: + return [rule.get("message", f"{target}() 至多调用 {max_count} 次,当前 {count} 次")] + return [] diff --git a/ast_checker/engines/method_call.py b/ast_checker/engines/method_call.py new file mode 100644 index 0000000..0dfd3c8 --- /dev/null +++ b/ast_checker/engines/method_call.py @@ -0,0 +1,38 @@ +from .base import BaseEngine + +CALL_NODE_TYPES = { + "Python3": "call", + "C": "call_expression", +} + + +class _MethodCallBase(BaseEngine): + def _find_method_calls(self, root, method_name, language): + if language == "C": + return [] + call_type = CALL_NODE_TYPES.get(language, "call") + calls = self.collect_nodes(root, call_type) + matches = [] + for call in calls: + func_node = call.child_by_field_name("function") + if func_node and func_node.type == "attribute": + attr_node = func_node.child_by_field_name("attribute") + if attr_node and attr_node.text.decode() == method_name: + matches.append(call) + return matches + + +class MustCallMethodEngine(_MethodCallBase): + def check(self, tree, rule, language, mapping): + target = rule["target"] + if not self._find_method_calls(tree.root_node, target, language): + return [rule.get("message", f"必须调用 .{target}()")] + return [] + + +class MustNotCallMethodEngine(_MethodCallBase): + def check(self, tree, rule, language, mapping): + target = rule["target"] + if self._find_method_calls(tree.root_node, target, language): + return [rule.get("message", f"不能调用 .{target}()")] + return [] diff --git a/ast_checker/engines/node_count.py b/ast_checker/engines/node_count.py new file mode 100644 index 0000000..6b9b6db --- /dev/null +++ b/ast_checker/engines/node_count.py @@ -0,0 +1,16 @@ +from .base import BaseEngine + + +class CountNodeEngine(BaseEngine): + def check(self, tree, rule, language, mapping): + target = rule["target"] + node_type = mapping.get(target, target) + nodes = self.collect_nodes(tree.root_node, node_type) + count = len(nodes) + min_count = rule.get("min") + max_count = rule.get("max") + if min_count is not None and count < min_count: + return [rule.get("message", f"{target} 至少出现 {min_count} 次,当前 {count} 次")] + if max_count is not None and count > max_count: + return [rule.get("message", f"{target} 至多出现 {max_count} 次,当前 {count} 次")] + return [] diff --git a/ast_checker/engines/node_exists.py b/ast_checker/engines/node_exists.py new file mode 100644 index 0000000..b2f1b0f --- /dev/null +++ b/ast_checker/engines/node_exists.py @@ -0,0 +1,19 @@ +from .base import BaseEngine + + +class MustExistNodeEngine(BaseEngine): + def check(self, tree, rule, language, mapping): + target = rule["target"] + node_type = mapping.get(target, target) + if not self.has_node(tree.root_node, node_type): + return [rule.get("message", f"必须使用 {target}")] + return [] + + +class MustNotExistNodeEngine(BaseEngine): + def check(self, tree, rule, language, mapping): + target = rule["target"] + node_type = mapping.get(target, target) + if self.has_node(tree.root_node, node_type): + return [rule.get("message", f"不能使用 {target}")] + return [] diff --git a/ast_checker/engines/operator.py b/ast_checker/engines/operator.py new file mode 100644 index 0000000..2f8937d --- /dev/null +++ b/ast_checker/engines/operator.py @@ -0,0 +1,10 @@ +from .base import BaseEngine + + +class MustUseOperatorEngine(BaseEngine): + def check(self, tree, rule, language, mapping): + target = rule["target"] + mapped_op = mapping.get(target, target) + if not self.has_node(tree.root_node, mapped_op): + return [rule.get("message", f"必须使用 {target} 运算符")] + return []