Files
OnlineJudge/utils/websocket.py
2025-10-07 17:03:14 +08:00

76 lines
2.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
WebSocket utility functions for pushing real-time updates.

Messages are published through the Django Channels channel layer to
per-user groups named ``submission_user_<user_id>``. The channel layer's
``group_send`` is a coroutine, so the helpers wrap it with
``asgiref.sync.async_to_sync`` to make them callable from synchronous code.
"""
import logging

from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer

logger = logging.getLogger(__name__)


def push_submission_update(submission_id: str, user_id: int, data: dict):
    """
    Push a submission status update to the given user's WebSocket connections.

    The event is sent to the channel-layer group ``submission_user_<user_id>``
    with ``"type": "submission_update"`` (handled by
    ``SubmissionConsumer.submission_update``) and ``data`` under the ``"data"``
    key.

    This is best-effort. If no channel layer is configured, a warning is logged
    and the function returns. Any error raised while sending is logged with its
    traceback instead of being propagated to the caller.

    Args:
        submission_id: Submission ID; used only in log messages.
        user_id: ID of the user whose connections should receive the update.
        data: Payload to send. It should contain fields such as ``type``,
            ``submission_id`` and ``result``; ``data.get("status")`` is
            included in the success log line.
    """
    channel_layer = get_channel_layer()
    if channel_layer is None:
        logger.warning("Channel layer is not configured, cannot push submission update")
        return

    # Must match the group name used in SubmissionConsumer.
    group_name = f"submission_user_{user_id}"

    try:
        # "type" names the consumer method that handles this event
        # (SubmissionConsumer.submission_update). group_send is async, so it is
        # wrapped with async_to_sync for use from synchronous code.
        async_to_sync(channel_layer.group_send)(
            group_name,
            {
                "type": "submission_update",
                "data": data,
            },
        )
        # Kept inside the try so a malformed `data` (e.g. not a dict) is still
        # logged instead of escaping this best-effort helper.
        logger.info(
            "Pushed submission update: submission_id=%s, user_id=%s, status=%s",
            submission_id,
            user_id,
            data.get("status"),
        )
    except Exception as e:
        # Best-effort push: never break the caller, but keep the traceback.
        logger.exception(
            "Failed to push submission update: submission_id=%s, user_id=%s, error=%s",
            submission_id,
            user_id,
            e,
        )


def push_to_user(user_id: int, message_type: str, data: dict):
    """
    Push a custom message to the given user's WebSocket connections.

    The delivered payload is ``data`` plus a ``"type"`` key set to
    ``message_type``. The explicit ``message_type`` always wins: any ``"type"``
    key already present in ``data`` is dropped rather than overriding it.

    The message is routed through the same ``submission_user_<user_id>`` group
    and ``submission_update`` consumer handler that ``push_submission_update``
    uses.

    This is best-effort. If no channel layer is configured, a warning is logged
    and the function returns. Any error while building or sending the message
    is logged with its traceback instead of being raised.

    Args:
        user_id: ID of the user whose connections should receive the message.
        message_type: Value placed under the payload's ``"type"`` key.
        data: Additional fields merged into the payload.
    """
    channel_layer = get_channel_layer()
    if channel_layer is None:
        logger.warning("Channel layer is not configured, cannot push message")
        return

    # Same per-user group as submission updates (must match SubmissionConsumer).
    group_name = f"submission_user_{user_id}"

    try:
        # Build the payload inside the try so invalid `data` is logged, not raised.
        # "type" goes first, and a "type" key in `data` is skipped so it cannot
        # clobber the explicit message_type argument.
        payload = {"type": message_type}
        payload.update((key, value) for key, value in data.items() if key != "type")

        async_to_sync(channel_layer.group_send)(
            group_name,
            {
                # Dispatched to SubmissionConsumer.submission_update.
                "type": "submission_update",
                "data": payload,
            },
        )
        logger.info("Pushed message to user %s: type=%s", user_id, message_type)
    except Exception as e:
        # Best-effort push: never break the caller, but keep the traceback.
        logger.exception("Failed to push message to user %s: error=%s", user_id, e)