From c6b368dc825a0aab7cc7e532ce9349f9849d270a Mon Sep 17 00:00:00 2001 From: yuetsh <517252939@qq.com> Date: Mon, 25 May 2026 22:24:53 -0600 Subject: [PATCH] update --- ast_checker/checker.py | 14 +++++--- ast_checker/engines/base.py | 3 ++ ast_checker/engines/function_call.py | 49 +++++++++++++++++++++------- ast_checker/engines/method_call.py | 22 +++++++++---- ast_checker/engines/node_count.py | 31 ++++++++++++++---- ast_checker/engines/node_exists.py | 22 +++++++++---- ast_checker/engines/operator.py | 11 +++++-- judge/dispatcher.py | 4 +-- 8 files changed, 116 insertions(+), 40 deletions(-) diff --git a/ast_checker/checker.py b/ast_checker/checker.py index 3d8db9c..059d0b0 100644 --- a/ast_checker/checker.py +++ b/ast_checker/checker.py @@ -4,7 +4,7 @@ from .engines import get_engine from .mappings import get_language, get_mapping -def check_ast(code: str, language: str, rules: list[dict]) -> tuple[bool, list[str]]: +def check_ast(code: str, language: str, rules: list[dict]) -> tuple[bool, list[dict]]: if not rules: return True, [] @@ -20,12 +20,16 @@ def check_ast(code: str, language: str, rules: list[dict]) -> tuple[bool, list[s except Exception: return True, [] - errors = [] + results = [] + all_passed = True for rule in rules: engine = get_engine(rule.get("engine", "")) if engine is None: continue - rule_errors = engine.check(tree, rule, language, mapping) - errors.extend(rule_errors) + errors = engine.check(tree, rule, language, mapping) + passed = len(errors) == 0 + if not passed: + all_passed = False + results.append({"description": engine.describe(rule, language, mapping), "passed": passed}) - return len(errors) == 0, errors + return all_passed, results diff --git a/ast_checker/engines/base.py b/ast_checker/engines/base.py index 3b28da8..2509f13 100644 --- a/ast_checker/engines/base.py +++ b/ast_checker/engines/base.py @@ -16,3 +16,6 @@ class BaseEngine: def check(self, tree, rule, language, mapping) -> list[str]: raise NotImplementedError + + def describe(self, rule, language, mapping) -> str: + raise NotImplementedError diff --git a/ast_checker/engines/function_call.py b/ast_checker/engines/function_call.py index d77ce1a..09bd020 100644 --- a/ast_checker/engines/function_call.py +++ b/ast_checker/engines/function_call.py @@ -19,29 +19,56 @@ class _FunctionCallBase(BaseEngine): class MustCallFunctionEngine(_FunctionCallBase): + def _message(self, rule): + return rule.get("message") or f"必须调用 {rule['target']}()" + def check(self, tree, rule, language, mapping): - target = rule["target"] - if not self._find_function_calls(tree.root_node, target, language): - return [rule.get("message", f"必须调用 {target}()")] + if not self._find_function_calls(tree.root_node, rule["target"], language): + return [self._message(rule)] return [] + def describe(self, rule, language, mapping): + return self._message(rule) + class MustNotCallFunctionEngine(_FunctionCallBase): + def _message(self, rule): + return rule.get("message") or f"不能调用 {rule['target']}()" + def check(self, tree, rule, language, mapping): - target = rule["target"] - if self._find_function_calls(tree.root_node, target, language): - return [rule.get("message", f"不能调用 {target}()")] + if self._find_function_calls(tree.root_node, rule["target"], language): + return [self._message(rule)] return [] + def describe(self, rule, language, mapping): + return self._message(rule) + class CountFunctionCallEngine(_FunctionCallBase): - def check(self, tree, rule, language, mapping): + def _message(self, rule, count): target = rule["target"] - count = len(self._find_function_calls(tree.root_node, target, language)) min_count = rule.get("min") max_count = rule.get("max") if min_count is not None and count < min_count: - return [rule.get("message", f"{target}() 至少调用 {min_count} 次,当前 {count} 次")] + return rule.get("message") or f"{target}() 至少调用 {min_count} 次,当前 {count} 次" if max_count is not None and count > max_count: - return [rule.get("message", f"{target}() 至多调用 {max_count} 次,当前 {count} 次")] - return [] + return rule.get("message") or f"{target}() 至多调用 {max_count} 次,当前 {count} 次" + return None + + def check(self, tree, rule, language, mapping): + count = len(self._find_function_calls(tree.root_node, rule["target"], language)) + msg = self._message(rule, count) + return [msg] if msg else [] + + def describe(self, rule, language, mapping): + target = rule["target"] + min_count = rule.get("min") + max_count = rule.get("max") + if rule.get("message"): + return rule["message"] + parts = [] + if min_count is not None: + parts.append(f"至少 {min_count} 次") + if max_count is not None: + parts.append(f"至多 {max_count} 次") + return f"{target}() " + "、".join(parts) diff --git a/ast_checker/engines/method_call.py b/ast_checker/engines/method_call.py index 0dfd3c8..fd9c47a 100644 --- a/ast_checker/engines/method_call.py +++ b/ast_checker/engines/method_call.py @@ -23,16 +23,26 @@ class _MethodCallBase(BaseEngine): class MustCallMethodEngine(_MethodCallBase): + def _message(self, rule): + return rule.get("message") or f"必须调用 .{rule['target']}()" + def check(self, tree, rule, language, mapping): - target = rule["target"] - if not self._find_method_calls(tree.root_node, target, language): - return [rule.get("message", f"必须调用 .{target}()")] + if not self._find_method_calls(tree.root_node, rule["target"], language): + return [self._message(rule)] return [] + def describe(self, rule, language, mapping): + return self._message(rule) + class MustNotCallMethodEngine(_MethodCallBase): + def _message(self, rule): + return rule.get("message") or f"不能调用 .{rule['target']}()" + def check(self, tree, rule, language, mapping): - target = rule["target"] - if self._find_method_calls(tree.root_node, target, language): - return [rule.get("message", f"不能调用 .{target}()")] + if self._find_method_calls(tree.root_node, rule["target"], language): + return [self._message(rule)] return [] + + def describe(self, rule, language, mapping): + return self._message(rule) diff --git a/ast_checker/engines/node_count.py b/ast_checker/engines/node_count.py index 6b9b6db..c76e106 100644 --- a/ast_checker/engines/node_count.py +++ b/ast_checker/engines/node_count.py @@ -2,15 +2,32 @@ from .base import BaseEngine class CountNodeEngine(BaseEngine): - def check(self, tree, rule, language, mapping): + def _message(self, rule, count): target = rule["target"] - node_type = mapping.get(target, target) - nodes = self.collect_nodes(tree.root_node, node_type) - count = len(nodes) min_count = rule.get("min") max_count = rule.get("max") if min_count is not None and count < min_count: - return [rule.get("message", f"{target} 至少出现 {min_count} 次,当前 {count} 次")] + return rule.get("message") or f"{target} 至少出现 {min_count} 次,当前 {count} 次" if max_count is not None and count > max_count: - return [rule.get("message", f"{target} 至多出现 {max_count} 次,当前 {count} 次")] - return [] + return rule.get("message") or f"{target} 至多出现 {max_count} 次,当前 {count} 次" + return None + + def check(self, tree, rule, language, mapping): + target = rule["target"] + node_type = mapping.get(target, target) + count = len(self.collect_nodes(tree.root_node, node_type)) + msg = self._message(rule, count) + return [msg] if msg else [] + + def describe(self, rule, language, mapping): + target = rule["target"] + min_count = rule.get("min") + max_count = rule.get("max") + if rule.get("message"): + return rule["message"] + parts = [] + if min_count is not None: + parts.append(f"至少 {min_count} 次") + if max_count is not None: + parts.append(f"至多 {max_count} 次") + return f"{target} " + "、".join(parts) diff --git a/ast_checker/engines/node_exists.py b/ast_checker/engines/node_exists.py index b2f1b0f..eca4119 100644 --- a/ast_checker/engines/node_exists.py +++ b/ast_checker/engines/node_exists.py @@ -2,18 +2,28 @@ from .base import BaseEngine class MustExistNodeEngine(BaseEngine): + def _message(self, rule): + return rule.get("message") or f"必须使用 {rule['target']}" + def check(self, tree, rule, language, mapping): - target = rule["target"] - node_type = mapping.get(target, target) + node_type = mapping.get(rule["target"], rule["target"]) if not self.has_node(tree.root_node, node_type): - return [rule.get("message", f"必须使用 {target}")] + return [self._message(rule)] return [] + def describe(self, rule, language, mapping): + return self._message(rule) + class MustNotExistNodeEngine(BaseEngine): + def _message(self, rule): + return rule.get("message") or f"不能使用 {rule['target']}" + def check(self, tree, rule, language, mapping): - target = rule["target"] - node_type = mapping.get(target, target) + node_type = mapping.get(rule["target"], rule["target"]) if self.has_node(tree.root_node, node_type): - return [rule.get("message", f"不能使用 {target}")] + return [self._message(rule)] return [] + + def describe(self, rule, language, mapping): + return self._message(rule) diff --git a/ast_checker/engines/operator.py b/ast_checker/engines/operator.py index 2f8937d..740ba84 100644 --- a/ast_checker/engines/operator.py +++ b/ast_checker/engines/operator.py @@ -2,9 +2,14 @@ from .base import BaseEngine class MustUseOperatorEngine(BaseEngine): + def _message(self, rule): + return rule.get("message") or f"必须使用 {rule['target']} 运算符" + def check(self, tree, rule, language, mapping): - target = rule["target"] - mapped_op = mapping.get(target, target) + mapped_op = mapping.get(rule["target"], rule["target"]) if not self.has_node(tree.root_node, mapped_op): - return [rule.get("message", f"必须使用 {target} 运算符")] + return [self._message(rule)] return [] + + def describe(self, rule, language, mapping): + return self._message(rule) diff --git a/judge/dispatcher.py b/judge/dispatcher.py index 0b8fba5..e37cfed 100644 --- a/judge/dispatcher.py +++ b/judge/dispatcher.py @@ -213,10 +213,10 @@ class JudgeDispatcher(DispatcherBase): if ast_rules and language in ast_rules: from ast_checker.checker import check_ast - passed, errors = check_ast(self.submission.code, language, ast_rules[language]) + passed, results = check_ast(self.submission.code, language, ast_rules[language]) if not passed: self.submission.result = JudgeStatus.AST_CHECK_FAILED - self.submission.statistic_info["err_info"] = "\n".join(errors) + self.submission.statistic_info["ast_results"] = results self.submission.save(update_fields=["result", "info", "statistic_info"]) # 推送判题完成状态