From 8d40a7b2f02662a6ecfc23fa2da6a039d7f2f58c Mon Sep 17 00:00:00 2001 From: yuetsh <517252939@qq.com> Date: Sun, 14 Jun 2026 06:57:55 -0600 Subject: [PATCH] feat: add POST /format endpoint --- main.py | 19 ++++++++++++++++++- schemas.py | 13 +++++++++++++ test_main.py | 18 ++++++++++++++++++ 3 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 test_main.py diff --git a/main.py b/main.py index ec4e2b8..3eb77fa 100644 --- a/main.py +++ b/main.py @@ -5,9 +5,16 @@ from fastapi.responses import StreamingResponse import os import json from openai import OpenAI -from schemas import PresetCodeCreate, AIAnalysisRequest, DebugRequest +from schemas import ( + PresetCodeCreate, + AIAnalysisRequest, + DebugRequest, + FormatRequest, + FormatResponse, +) from database import DatabaseService from pg_logger import exec_script_str_local +from formatter import format_code, FormatError from dotenv import load_dotenv @@ -139,6 +146,16 @@ async def debug(request: DebugRequest): return {"data": data} +@app.post("/format", response_model=FormatResponse) +async def format_code_endpoint(request: FormatRequest) -> FormatResponse: + """格式化代码""" + try: + formatted = format_code(request.code, request.language) + except FormatError as e: + raise HTTPException(status_code=400, detail=str(e)) + return FormatResponse(code=formatted) + + if __name__ == "__main__": import uvicorn diff --git a/schemas.py b/schemas.py index dab5bea..e43b784 100644 --- a/schemas.py +++ b/schemas.py @@ -32,3 +32,16 @@ class DebugRequest(BaseModel): """调试请求模式,用于调试 Python 代码""" code: str inputs: List[str] + + +class FormatRequest(BaseModel): + """格式化代码的请求模式""" + + code: str + language: str + + +class FormatResponse(BaseModel): + """格式化代码的响应模式""" + + code: str diff --git a/test_main.py b/test_main.py new file mode 100644 index 0000000..f1b59d9 --- /dev/null +++ b/test_main.py @@ -0,0 +1,18 @@ +from fastapi.testclient import TestClient + +from main import app + +client = TestClient(app) + + +def test_format_endpoint_formats_python_code(): + response = client.post("/format", json={"code": "x=1\n", "language": "python"}) + assert response.status_code == 200 + assert response.json() == {"code": "x = 1\n"} + + +def test_format_endpoint_returns_400_on_syntax_error(): + response = client.post( + "/format", json={"code": "def foo(:\n", "language": "python"} + ) + assert response.status_code == 400