feat: add AST checker engine framework with 9 Phase 1 engines
This commit is contained in:
21
ast_checker/engines/__init__.py
Normal file
21
ast_checker/engines/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
from .function_call import CountFunctionCallEngine, MustCallFunctionEngine, MustNotCallFunctionEngine
|
||||||
|
from .method_call import MustCallMethodEngine, MustNotCallMethodEngine
|
||||||
|
from .node_count import CountNodeEngine
|
||||||
|
from .node_exists import MustExistNodeEngine, MustNotExistNodeEngine
|
||||||
|
from .operator import MustUseOperatorEngine
|
||||||
|
|
||||||
|
ENGINES = {
|
||||||
|
"must_exist_node": MustExistNodeEngine(),
|
||||||
|
"must_not_exist_node": MustNotExistNodeEngine(),
|
||||||
|
"count_node": CountNodeEngine(),
|
||||||
|
"must_call_function": MustCallFunctionEngine(),
|
||||||
|
"must_not_call_function": MustNotCallFunctionEngine(),
|
||||||
|
"count_function_call": CountFunctionCallEngine(),
|
||||||
|
"must_call_method": MustCallMethodEngine(),
|
||||||
|
"must_not_call_method": MustNotCallMethodEngine(),
|
||||||
|
"must_use_operator": MustUseOperatorEngine(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_engine(name: str):
|
||||||
|
return ENGINES.get(name)
|
||||||
18
ast_checker/engines/base.py
Normal file
18
ast_checker/engines/base.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
class BaseEngine:
|
||||||
|
@staticmethod
|
||||||
|
def collect_nodes(node, node_type):
|
||||||
|
results = []
|
||||||
|
if node.type == node_type:
|
||||||
|
results.append(node)
|
||||||
|
for child in node.children:
|
||||||
|
results.extend(BaseEngine.collect_nodes(child, node_type))
|
||||||
|
return results
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def has_node(node, node_type):
|
||||||
|
if node.type == node_type:
|
||||||
|
return True
|
||||||
|
return any(BaseEngine.has_node(child, node_type) for child in node.children)
|
||||||
|
|
||||||
|
def check(self, tree, rule, language, mapping) -> list[str]:
|
||||||
|
raise NotImplementedError
|
||||||
47
ast_checker/engines/function_call.py
Normal file
47
ast_checker/engines/function_call.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
from .base import BaseEngine
|
||||||
|
|
||||||
|
CALL_NODE_TYPES = {
|
||||||
|
"Python3": "call",
|
||||||
|
"C": "call_expression",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class _FunctionCallBase(BaseEngine):
|
||||||
|
def _find_function_calls(self, root, func_name, language):
|
||||||
|
call_type = CALL_NODE_TYPES.get(language, "call")
|
||||||
|
calls = self.collect_nodes(root, call_type)
|
||||||
|
matches = []
|
||||||
|
for call in calls:
|
||||||
|
func_node = call.child_by_field_name("function")
|
||||||
|
if func_node and func_node.type == "identifier" and func_node.text.decode() == func_name:
|
||||||
|
matches.append(call)
|
||||||
|
return matches
|
||||||
|
|
||||||
|
|
||||||
|
class MustCallFunctionEngine(_FunctionCallBase):
|
||||||
|
def check(self, tree, rule, language, mapping):
|
||||||
|
target = rule["target"]
|
||||||
|
if not self._find_function_calls(tree.root_node, target, language):
|
||||||
|
return [rule.get("message", f"必须调用 {target}()")]
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class MustNotCallFunctionEngine(_FunctionCallBase):
|
||||||
|
def check(self, tree, rule, language, mapping):
|
||||||
|
target = rule["target"]
|
||||||
|
if self._find_function_calls(tree.root_node, target, language):
|
||||||
|
return [rule.get("message", f"不能调用 {target}()")]
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class CountFunctionCallEngine(_FunctionCallBase):
|
||||||
|
def check(self, tree, rule, language, mapping):
|
||||||
|
target = rule["target"]
|
||||||
|
count = len(self._find_function_calls(tree.root_node, target, language))
|
||||||
|
min_count = rule.get("min")
|
||||||
|
max_count = rule.get("max")
|
||||||
|
if min_count is not None and count < min_count:
|
||||||
|
return [rule.get("message", f"{target}() 至少调用 {min_count} 次,当前 {count} 次")]
|
||||||
|
if max_count is not None and count > max_count:
|
||||||
|
return [rule.get("message", f"{target}() 至多调用 {max_count} 次,当前 {count} 次")]
|
||||||
|
return []
|
||||||
38
ast_checker/engines/method_call.py
Normal file
38
ast_checker/engines/method_call.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
from .base import BaseEngine
|
||||||
|
|
||||||
|
CALL_NODE_TYPES = {
|
||||||
|
"Python3": "call",
|
||||||
|
"C": "call_expression",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class _MethodCallBase(BaseEngine):
|
||||||
|
def _find_method_calls(self, root, method_name, language):
|
||||||
|
if language == "C":
|
||||||
|
return []
|
||||||
|
call_type = CALL_NODE_TYPES.get(language, "call")
|
||||||
|
calls = self.collect_nodes(root, call_type)
|
||||||
|
matches = []
|
||||||
|
for call in calls:
|
||||||
|
func_node = call.child_by_field_name("function")
|
||||||
|
if func_node and func_node.type == "attribute":
|
||||||
|
attr_node = func_node.child_by_field_name("attribute")
|
||||||
|
if attr_node and attr_node.text.decode() == method_name:
|
||||||
|
matches.append(call)
|
||||||
|
return matches
|
||||||
|
|
||||||
|
|
||||||
|
class MustCallMethodEngine(_MethodCallBase):
|
||||||
|
def check(self, tree, rule, language, mapping):
|
||||||
|
target = rule["target"]
|
||||||
|
if not self._find_method_calls(tree.root_node, target, language):
|
||||||
|
return [rule.get("message", f"必须调用 .{target}()")]
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class MustNotCallMethodEngine(_MethodCallBase):
|
||||||
|
def check(self, tree, rule, language, mapping):
|
||||||
|
target = rule["target"]
|
||||||
|
if self._find_method_calls(tree.root_node, target, language):
|
||||||
|
return [rule.get("message", f"不能调用 .{target}()")]
|
||||||
|
return []
|
||||||
16
ast_checker/engines/node_count.py
Normal file
16
ast_checker/engines/node_count.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
from .base import BaseEngine
|
||||||
|
|
||||||
|
|
||||||
|
class CountNodeEngine(BaseEngine):
|
||||||
|
def check(self, tree, rule, language, mapping):
|
||||||
|
target = rule["target"]
|
||||||
|
node_type = mapping.get(target, target)
|
||||||
|
nodes = self.collect_nodes(tree.root_node, node_type)
|
||||||
|
count = len(nodes)
|
||||||
|
min_count = rule.get("min")
|
||||||
|
max_count = rule.get("max")
|
||||||
|
if min_count is not None and count < min_count:
|
||||||
|
return [rule.get("message", f"{target} 至少出现 {min_count} 次,当前 {count} 次")]
|
||||||
|
if max_count is not None and count > max_count:
|
||||||
|
return [rule.get("message", f"{target} 至多出现 {max_count} 次,当前 {count} 次")]
|
||||||
|
return []
|
||||||
19
ast_checker/engines/node_exists.py
Normal file
19
ast_checker/engines/node_exists.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
from .base import BaseEngine
|
||||||
|
|
||||||
|
|
||||||
|
class MustExistNodeEngine(BaseEngine):
|
||||||
|
def check(self, tree, rule, language, mapping):
|
||||||
|
target = rule["target"]
|
||||||
|
node_type = mapping.get(target, target)
|
||||||
|
if not self.has_node(tree.root_node, node_type):
|
||||||
|
return [rule.get("message", f"必须使用 {target}")]
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class MustNotExistNodeEngine(BaseEngine):
|
||||||
|
def check(self, tree, rule, language, mapping):
|
||||||
|
target = rule["target"]
|
||||||
|
node_type = mapping.get(target, target)
|
||||||
|
if self.has_node(tree.root_node, node_type):
|
||||||
|
return [rule.get("message", f"不能使用 {target}")]
|
||||||
|
return []
|
||||||
10
ast_checker/engines/operator.py
Normal file
10
ast_checker/engines/operator.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
from .base import BaseEngine
|
||||||
|
|
||||||
|
|
||||||
|
class MustUseOperatorEngine(BaseEngine):
|
||||||
|
def check(self, tree, rule, language, mapping):
|
||||||
|
target = rule["target"]
|
||||||
|
mapped_op = mapping.get(target, target)
|
||||||
|
if not self.has_node(tree.root_node, mapped_op):
|
||||||
|
return [rule.get("message", f"必须使用 {target} 运算符")]
|
||||||
|
return []
|
||||||
Reference in New Issue
Block a user