Files
Practical_Training_Assignment/backend/app/api/v1/endpoints/websocket_service.py
2025-07-01 01:27:29 +08:00

121 lines
4.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# websocket_service.py
import asyncio
import json
import uuid
from typing import Set

from aip import AipSpeech
from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from app.constants.asr import APP_ID, API_KEY, SECRET_KEY

from . import tts
from .voice_conversation import process_voice_conversation
# FastAPI router carrying the WebSocket endpoint defined in this module.
router = APIRouter()
# Sockets currently registered by websocket_online_count; the size of this set
# is the "online_count" value sent to clients.
active_connections: Set[WebSocket] = set()
# Module-wide Baidu speech client, constructed once at import time from the
# credentials in app.constants.asr and shared by every connection.
asr_client = AipSpeech(APP_ID, API_KEY, SECRET_KEY)
async def asr_buffer(
    buffer_data: bytes,
    *,
    client=None,
    audio_format: str = 'pcm',
    rate: int = 16000,
    dev_pid: int = 1537,
) -> str:
    """Transcribe one buffered utterance with the Baidu speech client.

    Args:
        buffer_data: Audio bytes accumulated from the client.
        client: Object exposing ``asr(speech, format, rate, options)``;
            defaults to the module-level ``asr_client``.
        audio_format: Audio format passed to the SDK (default ``'pcm'``).
        rate: Sample rate passed to the SDK (default ``16000``).
        dev_pid: Recognition model id passed in the SDK options
            (default ``1537``, the value this endpoint has always used).

    Returns:
        The first recognition candidate, or the placeholder '语音转换失败'
        ("speech conversion failed") when the buffer is empty, the SDK call
        raises, the response does not report ``err_msg == 'success.'``, or
        the response carries no candidates.
    """
    failed = '语音转换失败'
    if not buffer_data:
        # Nothing was recorded; skip the network round-trip.
        return failed
    if client is None:
        client = asr_client
    loop = asyncio.get_running_loop()
    try:
        # client.asr is a synchronous call; running it in the default thread
        # pool keeps it from blocking every other coroutine on this loop.
        result = await loop.run_in_executor(
            None, client.asr, buffer_data, audio_format, rate, {'dev_pid': dev_pid}
        )
    except Exception as e:
        print(f"语音识别请求异常: {e}")
        return failed
    if not isinstance(result, dict) or result.get('err_msg') != 'success.':
        return failed
    candidates = result.get('result')
    if not candidates:
        # A "success" response without candidates used to raise here.
        return failed
    return candidates[0]
async def broadcast_online_count():
    """Send the current online count to every tracked WebSocket.

    Each connection in ``active_connections`` receives
    ``{"online_count": <n>, "type": "count"}``; any connection whose send
    raises is dropped from the set.
    """
    data = {"online_count": len(active_connections), 'type': 'count'}
    dead = []
    # Iterate over a snapshot: every ``await`` yields to the event loop, where
    # other handlers may add or discard connections, and mutating a set while
    # it is being iterated raises RuntimeError.
    for ws in list(active_connections):
        try:
            await ws.send_json(data)
        except Exception:
            dead.append(ws)
    for ws in dead:
        # discard rather than remove: the connection's own handler may have
        # dropped it already while this broadcast was awaiting.
        active_connections.discard(ws)
@router.websocket("/websocket")
async def websocket_online_count(websocket: WebSocket):
    """Serve one client WebSocket: online count, buffered ASR and TTS.

    Binary frames are appended to a per-connection audio buffer (presumably
    16 kHz PCM, which is what asr_buffer assumes — confirm with the client).
    Text frames must be JSON objects and are dispatched on their "type":
    "ping", "asr_end", "tts_text" or "tts_cancel"; anything else, malformed
    JSON, or non-object JSON is ignored.

    On exit the socket is removed from ``active_connections``, its TTS state
    is released via ``tts.cleanup_connection_tts`` and the new count is
    broadcast to the remaining clients.
    """
    await websocket.accept()
    active_connections.add(websocket)
    await broadcast_online_count()
    # bytearray appends in amortized O(1); ``bytes +=`` re-copied the whole
    # buffer for every chunk (quadratic in the utterance length).
    audio_buffer = bytearray()
    try:
        while True:
            message = await websocket.receive()
            message_type = message.get("type")
            if message_type == "websocket.disconnect":
                # The raw receive() reports a client disconnect as a message
                # rather than raising WebSocketDisconnect; calling receive()
                # again after this would raise RuntimeError.
                break
            if message_type != "websocket.receive":
                continue
            chunk = message.get("bytes")
            if chunk:
                audio_buffer += chunk
                continue
            text = message.get("text")
            if not text:
                continue
            try:
                data = json.loads(text)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                continue
            if not isinstance(data, dict):
                # A bare list/number/string has no "type" to dispatch on.
                continue
            await _dispatch_text_message(websocket, data, audio_buffer)
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print(f"WebSocket异常: {e}")
    finally:
        active_connections.discard(websocket)
        try:
            await tts.cleanup_connection_tts(websocket)
        except Exception as e:
            # Best-effort cleanup that must not block the broadcast below;
            # not a bare except, so task cancellation still propagates.
            print(f"TTS资源清理异常: {e}")
        await broadcast_online_count()


async def _dispatch_text_message(websocket: WebSocket, data: dict, audio_buffer: bytearray):
    """Route one decoded JSON control message by its "type"; unknown types are ignored."""
    msg_type = data.get("type")
    if msg_type == "ping":
        await websocket.send_json({"online_count": len(active_connections), "type": "count"})
    elif msg_type == "asr_end":
        await _handle_asr_end(websocket, data, audio_buffer)
    elif msg_type == "tts_text":
        await _handle_tts_text(websocket, data)
    elif msg_type == "tts_cancel":
        await _handle_tts_cancel(websocket, data)


async def _handle_asr_end(websocket: WebSocket, data: dict, audio_buffer: bytearray):
    """Transcribe and clear the buffered audio, then reply or start a voice-conversation turn."""
    # Take the audio and clear the buffer *before* recognition so that a
    # failing ASR call or conversation turn cannot leak stale audio into the
    # next utterance, and so both reply paths reset it.
    audio = bytes(audio_buffer)
    audio_buffer.clear()
    asr_text = await asr_buffer(audio)
    # Use the client's messageId; generate one when it is absent or empty.
    message_id = data.get("messageId") or "voice_" + str(uuid.uuid4())
    if data.get("voiceConversation"):
        speaker = data.get("speaker")
        await process_voice_conversation(websocket, asr_text, message_id, speaker)
    else:
        await websocket.send_json({"type": "asr_result", "result": asr_text})


async def _handle_tts_text(websocket: WebSocket, data: dict):
    """Forward a TTS request to the tts module, reporting failures as "tts_error" messages."""
    message_id = data.get("messageId")
    text = data.get("text", "")
    speaker = data.get("speaker")
    if not message_id:
        await websocket.send_json({
            "type": "tts_error",
            "message": "缺少messageId参数"
        })
        return
    print(f"收到TTS文本请求 [{message_id}]: {text}")
    try:
        await tts.handle_tts_text(websocket, message_id, text, speaker)
    except Exception as e:
        print(f"TTS文本处理异常 [{message_id}]: {e}")
        await websocket.send_json({
            "type": "tts_error",
            "messageId": message_id,
            "message": f"TTS处理失败: {str(e)}"
        })


async def _handle_tts_cancel(websocket: WebSocket, data: dict):
    """Ask the tts module to cancel the request named by "messageId"; no-op without one."""
    message_id = data.get("messageId")
    if not message_id:
        return
    print(f"收到TTS取消请求 [{message_id}]")
    try:
        await tts.handle_tts_cancel(websocket, message_id)
    except Exception as e:
        print(f"TTS取消处理异常 [{message_id}]: {e}")