diff --git a/backend/app/api/v1/endpoints/tts.py b/backend/app/api/v1/endpoints/tts.py new file mode 100644 index 0000000..e3462c8 --- /dev/null +++ b/backend/app/api/v1/endpoints/tts.py @@ -0,0 +1,455 @@ +# tts.py +import uuid +import websockets +import time +import fastrand +import json +import asyncio +from typing import Dict, Any, Optional as OptionalType + +from app.constants.tts import APP_ID, TOKEN, SPEAKER + +# 协议常量保持不变... +PROTOCOL_VERSION = 0b0001 +DEFAULT_HEADER_SIZE = 0b0001 +FULL_CLIENT_REQUEST = 0b0001 +AUDIO_ONLY_RESPONSE = 0b1011 +FULL_SERVER_RESPONSE = 0b1001 +ERROR_INFORMATION = 0b1111 +MsgTypeFlagWithEvent = 0b100 +NO_SERIALIZATION = 0b0000 +JSON = 0b0001 +COMPRESSION_NO = 0b0000 + +# 事件类型 +EVENT_NONE = 0 +EVENT_Start_Connection = 1 +EVENT_FinishConnection = 2 +EVENT_ConnectionStarted = 50 +EVENT_StartSession = 100 +EVENT_FinishSession = 102 +EVENT_SessionStarted = 150 +EVENT_SessionFinished = 152 +EVENT_TaskRequest = 200 +EVENT_TTSSentenceEnd = 351 +EVENT_TTSResponse = 352 + + +# 所有类定义保持不变... 
class Header:
    """Fixed 4-byte header that prefixes every frame of the TTS binary protocol.

    The first three bytes each pack two 4-bit fields, high nibble first:
    protocol version / header size, message type / type-specific flags,
    serialization method / compression type. The fourth byte is reserved.
    """

    def __init__(self,
                 protocol_version=PROTOCOL_VERSION,
                 header_size=DEFAULT_HEADER_SIZE,
                 message_type=0,
                 message_type_specific_flags=0,
                 serial_method=NO_SERIALIZATION,
                 compression_type=COMPRESSION_NO,
                 reserved_data=0):
        self.header_size = header_size
        self.protocol_version = protocol_version
        self.message_type = message_type
        self.message_type_specific_flags = message_type_specific_flags
        self.serial_method = serial_method
        self.compression_type = compression_type
        self.reserved_data = reserved_data

    def as_bytes(self) -> bytes:
        """Return the 4-byte wire representation of this header."""
        return bytes([
            (self.protocol_version << 4) | self.header_size,
            (self.message_type << 4) | self.message_type_specific_flags,
            (self.serial_method << 4) | self.compression_type,
            self.reserved_data,
        ])


class Optional:
    """Optional frame section: event number, session id and sequence number.

    ``errorCode``, ``connectionId`` and ``response_meta_json`` are only filled
    in when a server frame is parsed; they are never serialized.

    NOTE: this class name shadows ``typing.Optional``.
    """

    def __init__(self, event=EVENT_NONE, sessionId=None, sequence=None):
        self.event = event
        self.sessionId = sessionId
        self.errorCode = 0
        self.connectionId = None
        self.response_meta_json = None
        self.sequence = sequence

    def as_bytes(self) -> bytes:
        """Serialize the fields that are set, in wire order.

        Layout: event (4-byte big-endian signed int, omitted when it is
        ``EVENT_NONE``), then the session id as a 4-byte length followed by
        its UTF-8 bytes (omitted when ``None``), then the sequence number
        (4-byte big-endian signed int, omitted when ``None``).

        Returns:
            An immutable ``bytes`` object, as the annotation promises (the
            previous implementation leaked its internal ``bytearray``).
        """
        parts = []
        if self.event != EVENT_NONE:
            parts.append(self.event.to_bytes(4, "big", signed=True))
        if self.sessionId is not None:
            session_id_bytes = self.sessionId.encode("utf-8")
            parts.append(len(session_id_bytes).to_bytes(4, "big", signed=True))
            parts.append(session_id_bytes)
        if self.sequence is not None:
            parts.append(self.sequence.to_bytes(4, "big", signed=True))
        return b"".join(parts)


class Response:
    """A parsed server frame: header, optional section and payload.

    ``payload`` holds raw bytes and ``payload_json`` holds decoded text; both
    start as ``None`` and are filled in by the response parser.
    """

    def __init__(self, header, optional):
        self.optional = optional
        self.header = header
        self.payload = None
        self.payload_json = None


# Utility functions remain unchanged...
def gen_log_id():
    """Generate the log id sent in the ``X-Tt-Logid`` request header.

    The id is ``"02"`` + the current epoch time in milliseconds + an all-zero
    placeholder string + a pseudo-random number rendered as at least eight
    lowercase hex digits.
    """
    ts = int(time.time() * 1000)
    r = fastrand.pcg32bounded(1 << 24) + (1 << 20)
    local_ip = "00000000000000000000000000000000"
    return f"02{ts}{local_ip}{r:08x}"


def get_payload_bytes(uid='1234', event=EVENT_NONE, text='', speaker='', audio_format='mp3',
                      audio_sample_rate=24000):
    """Build the UTF-8 encoded JSON payload of a session or task request.

    Args:
        uid: User id placed under ``user.uid``.
        event: Protocol event number the payload belongs to.
        text: Text to synthesize (left empty when only starting a session).
        speaker: Voice identifier.
        audio_format: Requested audio container/codec name.
        audio_sample_rate: Requested sample rate.

    Returns:
        The JSON document encoded as ``bytes``.
    """
    request = {
        "user": {"uid": uid},
        "event": event,
        "namespace": "BidirectionalTTS",
        "req_params": {
            "text": text,
            "speaker": speaker,
            "audio_params": {
                "format": audio_format,
                "sample_rate": audio_sample_rate,
                "enable_timestamp": True,
            }
        }
    }
    return json.dumps(request).encode()


def read_res_payload(res, offset):
    """Read a length-prefixed byte field from ``res`` starting at ``offset``.

    The field is a 4-byte big-endian signed length followed by that many bytes.

    Returns:
        ``(payload, new_offset)`` where ``new_offset`` points just past the field.
    """
    payload_size = int.from_bytes(res[offset: offset + 4], "big", signed=True)
    offset += 4
    payload = res[offset: offset + payload_size]
    offset += payload_size
    return payload, offset


def read_res_content(res, offset):
    """Read a length-prefixed UTF-8 string from ``res`` starting at ``offset``.

    Returns:
        ``(text, new_offset)`` where ``new_offset`` points just past the field.
    """
    raw, offset = read_res_payload(res, offset)
    return str(raw, encoding='utf8'), offset


def parser_response(res) -> Response:
    """Parse one binary frame received from the TTS service.

    Args:
        res: The raw frame. A ``str`` means the server sent a text frame,
            which is not part of the binary protocol.

    Returns:
        A ``Response`` whose header is always populated; the optional section
        and payload are populated according to message type and event.

    Raises:
        RuntimeError: If ``res`` is a text frame (its content becomes the
            message) or is shorter than the 4-byte header.
    """
    if isinstance(res, str):
        raise RuntimeError(res)
    if len(res) < 4:
        raise RuntimeError(f"TTS响应过短: {len(res)} 字节")

    response = Response(Header(), Optional())
    header = response.header
    nibble = 0x0f
    header.protocol_version = (res[0] >> 4) & nibble
    header.header_size = res[0] & nibble
    header.message_type = (res[1] >> 4) & nibble
    header.message_type_specific_flags = res[1] & nibble
    # Fix: the high nibble is obtained by shifting 4 bits, not by shifting by
    # the mask value (15), which always produced 0.
    header.serial_method = (res[2] >> 4) & nibble
    # Fix: store the value in the attribute Header actually defines.
    header.compression_type = res[2] & nibble
    # Kept as an alias for any code that read the old attribute name.
    header.message_compression = header.compression_type
    header.reserved_data = res[3]

    offset = 4
    optional = response.optional
    # Fix: `x == A or B` was always true, which made the error branch below
    # unreachable; test membership instead.
    if header.message_type in (FULL_SERVER_RESPONSE, AUDIO_ONLY_RESPONSE):
        if header.message_type_specific_flags == MsgTypeFlagWithEvent:
            optional.event = int.from_bytes(res[offset:offset + 4], "big", signed=True)
            offset += 4
            if optional.event == EVENT_NONE:
                return response
            elif optional.event == EVENT_ConnectionStarted:
                optional.connectionId, offset = read_res_content(res, offset)
            elif optional.event in (EVENT_SessionStarted, EVENT_SessionFinished):
                optional.sessionId, offset = read_res_content(res, offset)
                optional.response_meta_json, offset = read_res_content(res, offset)
            elif optional.event == EVENT_TTSResponse:
                optional.sessionId, offset = read_res_content(res, offset)
                response.payload, offset = read_res_payload(res, offset)
            elif optional.event == EVENT_TTSSentenceEnd:
                optional.sessionId, offset = read_res_content(res, offset)
                response.payload_json, offset = read_res_content(res, offset)
    elif header.message_type == ERROR_INFORMATION:
        optional.errorCode = int.from_bytes(res[offset:offset + 4], "big", signed=True)
        offset += 4
        response.payload, offset = read_res_payload(res, offset)
    return response


async def send_event(ws, header, optional=None, payload=None):
    """Assemble one client frame and send it over ``ws``.

    Args:
        ws: Connection object exposing an awaitable ``send``.
        header: Serialized 4-byte header.
        optional: Serialized optional section, appended as-is when given.
        payload: Payload bytes; written as a 4-byte big-endian signed length
            followed by the bytes when given.
    """
    frame = bytearray(header)
    if optional is not None:
        frame.extend(optional)
    if payload is not None:
        frame.extend(len(payload).to_bytes(4, 'big', signed=True))
        frame.extend(payload)
    await ws.send(frame)


# Changed: TTS state class, now carrying the message id and task tracking.
class TTSState:
    """State of one in-flight TTS request, identified by its message id."""

    def __init__(self, message_id: str):
        self.message_id = message_id
        # Upstream connection returned by create_tts_connection(); this is a
        # client-side connection, so it is not annotated with a server class.
        self.volc_ws: OptionalType[Any] = None
        self.session_id: OptionalType[str] = None
        # Task running process_tts_task() for this request.
        self.task: OptionalType[asyncio.Task] = None
        self.is_processing = False


def _is_current_task(task) -> bool:
    """Return True when ``task`` is the asyncio task executing right now."""
    try:
        return task is asyncio.current_task()
    except RuntimeError:
        # No running event loop, hence no current task.
        return False


# Global state management.
class TTSManager:
    """Registry of TTS requests per client WebSocket plus lookup tables."""

    def __init__(self):
        # WebSocket -> message id -> TTS state
        self.connections: Dict[Any, Dict[str, TTSState]] = {}
        # Session id -> message id, used to route responses
        self.session_to_message: Dict[str, str] = {}
        # Message id -> WebSocket
        self.message_to_websocket: Dict[str, Any] = {}

    def get_connection_states(self, websocket) -> Dict[str, TTSState]:
        """Return the state table of ``websocket``, creating it if missing."""
        return self.connections.setdefault(websocket, {})

    def add_tts_state(self, websocket, message_id: str) -> TTSState:
        """Register a fresh state for ``message_id``.

        A state already registered under the same id is cleaned up first,
        which cancels its task.
        """
        states = self.get_connection_states(websocket)
        if message_id in states:
            self.cleanup_message_state(websocket, message_id)

        tts_state = TTSState(message_id)
        states[message_id] = tts_state
        self.message_to_websocket[message_id] = websocket
        return tts_state

    def get_tts_state(self, websocket, message_id: str) -> OptionalType[TTSState]:
        """Return the state for ``message_id`` or ``None``.

        Pure lookup: unlike before, it no longer creates an empty table for an
        unknown (possibly already closed) websocket.
        """
        states = self.connections.get(websocket)
        if not states:
            return None
        return states.get(message_id)

    def register_session(self, session_id: str, message_id: str):
        """Record which message id an upstream session id belongs to."""
        self.session_to_message[session_id] = message_id

    def get_message_by_session(self, session_id: str) -> OptionalType[str]:
        """Return the message id registered for ``session_id``, if any."""
        return self.session_to_message.get(session_id)

    def get_websocket_by_message(self, message_id: str):
        """Return the websocket registered for ``message_id``, if any."""
        return self.message_to_websocket.get(message_id)

    def cleanup_message_state(self, websocket, message_id: str,
                              expected_state: OptionalType[TTSState] = None):
        """Remove the state of ``message_id`` and cancel its task.

        Args:
            websocket: Client connection the request belongs to.
            message_id: Request identifier.
            expected_state: When given, clean up only if the registered state
                is this very object. A task passes its own state here so that
                a cancelled, superseded task cannot tear down the request that
                replaced it under the same message id.
        """
        # Look up without creating: cleanup for a closed connection must not
        # re-insert that websocket into the registry.
        states = self.connections.get(websocket)
        if not states:
            return
        tts_state = states.get(message_id)
        if tts_state is None:
            return
        if expected_state is not None and tts_state is not expected_state:
            return

        # Cancel the task, unless it is the one performing this cleanup: a
        # finishing task must not mark itself as cancelled.
        task = tts_state.task
        if task and not task.done() and not _is_current_task(task):
            task.cancel()
        # Drop the lookup-table entries.
        if tts_state.session_id:
            self.session_to_message.pop(tts_state.session_id, None)
        self.message_to_websocket.pop(message_id, None)
        # Drop the state itself.
        del states[message_id]

    def cleanup_connection(self, websocket):
        """Remove every state registered for ``websocket``."""
        states = self.connections.get(websocket)
        if states is None:
            return
        # Iterate over a copy: cleanup_message_state mutates the table.
        for message_id in list(states):
            self.cleanup_message_state(websocket, message_id)
        self.connections.pop(websocket, None)


# Global TTS manager instance.
tts_manager = TTSManager()


# Open a dedicated TTS connection.
async def create_tts_connection() -> Any:
    """Open an upstream TTS WebSocket and complete the connection handshake.

    Sends ``EVENT_Start_Connection`` and waits for ``EVENT_ConnectionStarted``.

    Returns:
        The connected client WebSocket.

    Raises:
        Exception: If the first frame received is not the connection
            confirmation. The socket is closed before the error propagates.
    """
    log_id = gen_log_id()
    ws_header = {
        "X-Api-App-Key": APP_ID,
        "X-Api-Access-Key": TOKEN,
        "X-Api-Resource-Id": 'volc.service_type.10029',
        "X-Api-Connect-Id": str(uuid.uuid4()),
        "X-Tt-Logid": log_id,
    }
    url = 'wss://openspeech.bytedance.com/api/v3/tts/bidirection'
    volc_ws = await websockets.connect(url, additional_headers=ws_header, max_size=1000000000)

    try:
        # Start the connection.
        header = Header(message_type=FULL_CLIENT_REQUEST,
                        message_type_specific_flags=MsgTypeFlagWithEvent).as_bytes()
        optional = Optional(event=EVENT_Start_Connection).as_bytes()
        payload = str.encode("{}")
        await send_event(volc_ws, header, optional, payload)

        # Wait for the confirmation.
        raw_data = await volc_ws.recv()
        res = parser_response(raw_data)
        if res.optional.event != EVENT_ConnectionStarted:
            raise Exception("TTS连接失败")
    except BaseException:
        # Do not leak the socket when the handshake fails or is cancelled.
        try:
            await volc_ws.close()
        except Exception as close_err:
            print(f"关闭TTS连接失败: {close_err}")
        raise

    return volc_ws


async def _send_json_safely(websocket, payload) -> bool:
    """Send ``payload`` to the client; log instead of raising on failure.

    Used on error paths, where the client may already be gone and a second
    exception would only hide the first one.
    """
    try:
        await websocket.send_json(payload)
    except Exception as send_err:
        print(f"发送消息到前端失败: {send_err}")
        return False
    return True


# Process a single TTS task.
async def process_tts_task(websocket, message_id: str, text: str):
    """Synthesize ``text`` and stream the audio to ``websocket``.

    Runs as its own task. It opens a dedicated upstream connection, starts a
    session, submits the text, forwards every audio chunk to the client as a
    hex string and finally sends ``tts_audio_complete``. Failures are reported
    to the client as ``tts_error``.

    On cancellation nothing is sent to the client: every cancel path in this
    module has already deregistered the request (explicit cancel, replacement
    under the same message id, or disconnect), so a late ``tts_error`` would be
    attributed to the wrong request or written to a closed socket. The
    ``CancelledError`` is re-raised so the task is recorded as cancelled.
    """
    tts_state = None
    try:
        print(f"开始处理TTS任务 [{message_id}]: {text}")

        # Fetch the state registered by handle_tts_text().
        tts_state = tts_manager.get_tts_state(websocket, message_id)
        if not tts_state:
            raise Exception(f"找不到TTS状态: {message_id}")

        tts_state.is_processing = True

        # Dedicated upstream connection for this request.
        tts_state.volc_ws = await create_tts_connection()

        # Start a session.
        tts_state.session_id = uuid.uuid4().hex
        tts_manager.register_session(tts_state.session_id, message_id)

        print(f"创建TTS会话 [{message_id}]: {tts_state.session_id}")
        header = Header(message_type=FULL_CLIENT_REQUEST,
                        message_type_specific_flags=MsgTypeFlagWithEvent,
                        serial_method=JSON).as_bytes()
        optional = Optional(event=EVENT_StartSession, sessionId=tts_state.session_id).as_bytes()
        payload = get_payload_bytes(event=EVENT_StartSession, speaker=SPEAKER)
        await send_event(tts_state.volc_ws, header, optional, payload)

        raw_data = await tts_state.volc_ws.recv()
        res = parser_response(raw_data)
        if res.optional.event != EVENT_SessionStarted:
            raise Exception("TTS会话启动失败")
        print(f"TTS会话创建成功 [{message_id}]: {tts_state.session_id}")

        # Submit the text.
        print(f"发送文本到TTS服务 [{message_id}]...")
        header = Header(message_type=FULL_CLIENT_REQUEST,
                        message_type_specific_flags=MsgTypeFlagWithEvent,
                        serial_method=JSON).as_bytes()
        optional = Optional(event=EVENT_TaskRequest, sessionId=tts_state.session_id).as_bytes()
        payload = get_payload_bytes(event=EVENT_TaskRequest, text=text, speaker=SPEAKER)
        await send_event(tts_state.volc_ws, header, optional, payload)

        # Receive TTS responses and forward them to the frontend.
        print(f"开始接收TTS响应 [{message_id}]...")
        audio_count = 0

        try:
            while True:
                raw_data = await asyncio.wait_for(
                    tts_state.volc_ws.recv(),
                    timeout=30
                )
                res = parser_response(raw_data)

                # Fail fast on an upstream error frame instead of treating it
                # as an unknown event and waiting for the timeout.
                if res.header.message_type == ERROR_INFORMATION:
                    detail = res.payload.decode('utf8', errors='replace') if res.payload else ''
                    raise Exception(f"TTS服务返回错误 {res.optional.errorCode}: {detail}")

                print(f"收到TTS事件 [{message_id}]: {res.optional.event}")

                if res.optional.event == EVENT_TTSSentenceEnd:
                    # NOTE(review): streaming stops at the first sentence end
                    # and no finish-session event is ever sent; confirm that
                    # multi-sentence input is not truncated by this.
                    print(f"句子结束事件 [{message_id}] - 直接完成")
                    break

                elif res.optional.event == EVENT_SessionFinished:
                    print(f"收到会话结束事件 [{message_id}]")
                    break

                elif res.optional.event == EVENT_TTSResponse:
                    audio_count += 1
                    print(f"发送音频数据 [{message_id}] #{audio_count},大小: {len(res.payload)}")
                    # Forward the audio chunk together with the message id.
                    await websocket.send_json({
                        "type": "tts_audio_data",
                        "messageId": message_id,
                        "audioData": res.payload.hex()  # hex-encoded audio bytes
                    })
                else:
                    print(f"未知TTS事件 [{message_id}]: {res.optional.event}")

        except asyncio.TimeoutError:
            print(f"TTS响应超时 [{message_id}],强制结束")

        # Tell the client the stream is complete.
        await websocket.send_json({
            "type": "tts_audio_complete",
            "messageId": message_id
        })
        print(f"TTS处理完成 [{message_id}],共发送 {audio_count} 个音频包")

    except asyncio.CancelledError:
        print(f"TTS任务被取消 [{message_id}]")
        raise
    except Exception as e:
        print(f"TTS处理异常 [{message_id}]: {e}")
        import traceback
        traceback.print_exc()
        await _send_json_safely(websocket, {
            "type": "tts_error",
            "messageId": message_id,
            "message": f"TTS处理失败: {str(e)}"
        })
    finally:
        # Release resources.
        if tts_state:
            tts_state.is_processing = False
            if tts_state.volc_ws:
                try:
                    await tts_state.volc_ws.close()
                except Exception as close_err:
                    print(f"关闭TTS连接失败 [{message_id}]: {close_err}")
            # Remove only this task's own state; a newer request registered
            # under the same message id must survive.
            tts_manager.cleanup_message_state(websocket, message_id, tts_state)


# Start a text-to-speech conversion.
async def handle_tts_text(websocket, message_id: str, text: str):
    """Register a TTS request and start processing it in a background task."""
    tts_state = tts_manager.add_tts_state(websocket, message_id)
    tts_state.task = asyncio.create_task(
        process_tts_task(websocket, message_id, text)
    )


# Cancel a TTS task.
async def handle_tts_cancel(websocket, message_id: str):
    """Cancel the running task of ``message_id`` and confirm with ``tts_complete``.

    Nothing is sent when no task is running for that id; any remaining state
    is removed either way.
    """
    tts_state = tts_manager.get_tts_state(websocket, message_id)
    if tts_state and tts_state.task and not tts_state.task.done():
        tts_state.task.cancel()
        await websocket.send_json({
            "type": "tts_complete",
            "messageId": message_id
        })
    tts_manager.cleanup_message_state(websocket, message_id, tts_state)


# Release all TTS resources of a connection.
async def cleanup_connection_tts(websocket):
    """Cancel and remove every TTS request that belongs to ``websocket``."""
    print("清理连接的TTS资源...")
    tts_manager.cleanup_connection(websocket)
    print("TTS资源清理完成")


# ---------------------------------------------------------------------------
# NOTE: in SOURCE this physical line continues with the beginning of the diff
# for websocket_service.py. That text is not part of tts.py; it is kept below,
# commented out, so that none of it is lost.
#
# diff --git a/backend/app/api/v1/endpoints/websocket_service.py b/backend/app/api/v1/endpoints/websocket_service.py
# index 0bf474d..3e683ed 100644
# --- a/backend/app/api/v1/endpoints/websocket_service.py
# +++ b/backend/app/api/v1/endpoints/websocket_service.py
# @@ -1,14 +1,19 @@
# +# websocket_service.py
#  from fastapi import APIRouter, WebSocket, WebSocketDisconnect
#  from typing import Set
#  from aip import AipSpeech
#  from app.constants.asr import APP_ID, API_KEY, SECRET_KEY
#  import json
# +# Import the modified TTS module
# +from .
import tts + router = APIRouter() active_connections: Set[WebSocket] = set() asr_client = AipSpeech(APP_ID, API_KEY, SECRET_KEY) + async def asr_buffer(buffer_data: bytes) -> str: result = asr_client.asr(buffer_data, 'pcm', 16000, {'dev_pid': 1537}) if result.get('err_msg') == 'success.': @@ -16,6 +21,7 @@ async def asr_buffer(buffer_data: bytes) -> str: else: return '语音转换失败' + async def broadcast_online_count(): data = {"online_count": len(active_connections), 'type': 'count'} to_remove = set() @@ -27,12 +33,14 @@ async def broadcast_online_count(): for ws in to_remove: active_connections.remove(ws) + @router.websocket("/websocket") async def websocket_online_count(websocket: WebSocket): await websocket.accept() active_connections.add(websocket) await broadcast_online_count() temp_buffer = bytes() + try: while True: message = await websocket.receive() @@ -45,15 +53,59 @@ async def websocket_online_count(websocket: WebSocket): except Exception: continue msg_type = data.get("type") + if msg_type == "ping": await websocket.send_json({"online_count": len(active_connections), "type": "count"}) + elif msg_type == "asr_end": asr_text = await asr_buffer(temp_buffer) await websocket.send_json({"type": "asr_result", "result": asr_text}) temp_buffer = bytes() + + # 修改:TTS处理支持消息ID + elif msg_type == "tts_text": + message_id = data.get("messageId") + text = data.get("text", "") + + if not message_id: + await websocket.send_json({ + "type": "tts_error", + "message": "缺少messageId参数" + }) + continue + + print(f"收到TTS文本请求 [{message_id}]: {text}") + try: + await tts.handle_tts_text(websocket, message_id, text) + except Exception as e: + print(f"TTS文本处理异常 [{message_id}]: {e}") + await websocket.send_json({ + "type": "tts_error", + "messageId": message_id, + "message": f"TTS处理失败: {str(e)}" + }) + + elif msg_type == "tts_cancel": + message_id = data.get("messageId") + if message_id: + print(f"收到TTS取消请求 [{message_id}]") + try: + await tts.handle_tts_cancel(websocket, message_id) + 
except Exception as e: + print(f"TTS取消处理异常 [{message_id}]: {e}") + except WebSocketDisconnect: - active_connections.remove(websocket) - await broadcast_online_count() - except Exception: - active_connections.remove(websocket) + pass + except Exception as e: + print(f"WebSocket异常: {e}") + finally: + # 清理资源 + active_connections.discard(websocket) + + # 清理所有TTS资源 + try: + await tts.cleanup_connection_tts(websocket) + except: + pass + await broadcast_online_count() diff --git a/backend/app/constants/model_data.py b/backend/app/constants/model_data.py index c2761bd..cf1884c 100644 --- a/backend/app/constants/model_data.py +++ b/backend/app/constants/model_data.py @@ -19,7 +19,8 @@ MODEL_DATA = [ { "vendor": "Anthropic", "models": [ - {"model_id": "claude-sonnet-4-thinking", "model_name": "Claude Sonnet 4 thinking", "model_type": "reasoning"}, + {"model_id": "claude-sonnet-4-thinking", "model_name": "Claude Sonnet 4 thinking", + "model_type": "reasoning"}, {"model_id": "claude-sonnet-4", "model_name": "Claude Sonnet 4", "model_type": "text"}, ] }, @@ -27,6 +28,7 @@ MODEL_DATA = [ "vendor": "硅基流动", "models": [ {"model_id": "deepseek-v3", "model_name": "DeepSeek V3", "model_type": "text"}, + {"model_id": "deepseek-r1", "model_name": "DeepSeek R1", "model_type": "reasoning"}, ] } ] diff --git a/backend/app/constants/tts.py b/backend/app/constants/tts.py new file mode 100644 index 0000000..8e8c3aa --- /dev/null +++ b/backend/app/constants/tts.py @@ -0,0 +1,7 @@ +# APP_ID = '1142362958' +# TOKEN = 'O-a4JkyLFrYkME9no11DxFkOY-UnAoFF' +# SPEAKER = 'zh_male_beijingxiaoye_moon_bigtts' + +APP_ID = '2138450044' +TOKEN = 'V04_QumeQZhJrQ_In1Z0VBQm7n0ttMNO' +SPEAKER = 'zh_male_beijingxiaoye_moon_bigtts' \ No newline at end of file diff --git a/web/.prettierignore b/web/.prettierignore new file mode 100644 index 0000000..53c6465 --- /dev/null +++ b/web/.prettierignore @@ -0,0 +1,3 @@ +# .prettierignore +auto-imports.d.ts +components.d.ts diff --git a/web/.prettierrc.json 
b/web/.prettierrc.json new file mode 100644 index 0000000..c4b2197 --- /dev/null +++ b/web/.prettierrc.json @@ -0,0 +1,8 @@ +{ + "$schema": "https://json.schemastore.org/prettierrc", + "tabWidth": 2, + "singleQuote": false, + "printWidth": 80, + "trailingComma": "none", + "ignorePath": ".prettierignore" +} diff --git a/web/components.d.ts b/web/components.d.ts index c9dec3e..ddc6cfc 100644 --- a/web/components.d.ts +++ b/web/components.d.ts @@ -7,16 +7,23 @@ export {} /* prettier-ignore */ declare module 'vue' { export interface GlobalComponents { + Avatar: typeof import('./src/components/avatar.vue')['default'] Markdown: typeof import('./src/components/markdown.vue')['default'] NButton: typeof import('naive-ui')['NButton'] + NCollapse: typeof import('naive-ui')['NCollapse'] + NCollapseItem: typeof import('naive-ui')['NCollapseItem'] NConfigProvider: typeof import('naive-ui')['NConfigProvider'] + NDivider: typeof import('naive-ui')['NDivider'] + NImage: typeof import('naive-ui')['NImage'] NInput: typeof import('naive-ui')['NInput'] NMessageProvider: typeof import('naive-ui')['NMessageProvider'] NPopconfirm: typeof import('naive-ui')['NPopconfirm'] + NPopover: typeof import('naive-ui')['NPopover'] NScrollbar: typeof import('naive-ui')['NScrollbar'] NSelect: typeof import('naive-ui')['NSelect'] NTag: typeof import('naive-ui')['NTag'] RouterLink: typeof import('vue-router')['RouterLink'] RouterView: typeof import('vue-router')['RouterView'] + Tts: typeof import('./src/components/tts.vue')['default'] } } diff --git a/web/eslint.config.js b/web/eslint.config.js index 9e5e7f3..be6c010 100644 --- a/web/eslint.config.js +++ b/web/eslint.config.js @@ -1,39 +1,37 @@ -import antfu from "@antfu/eslint-config" +import antfu from "@antfu/eslint-config"; -export default antfu( - { - formatters: { - /** - * Format CSS, LESS, SCSS files, also the ` diff --git a/web/src/components/tts.vue b/web/src/components/tts.vue new file mode 100644 index 0000000..b7000e6 --- /dev/null +++ 
b/web/src/components/tts.vue @@ -0,0 +1,81 @@ + + + diff --git a/web/src/interfaces/chat_service.ts b/web/src/interfaces/chat_service.ts index d04f064..94e2f51 100644 --- a/web/src/interfaces/chat_service.ts +++ b/web/src/interfaces/chat_service.ts @@ -1,24 +1,33 @@ export interface IChatWithLLMRequest { - messages: Message[] + messages: Message[]; /** * 要使用的模型的 ID */ - model: string + model: string; } export interface Message { - content?: string - role?: string - [property: string]: any + content?: string; + thinking?: string; + role?: string; + usage?: UsageInfo; + id?: string; + [property: string]: any; } export interface ModelInfo { - model_id: string - model_name: string - model_type: string + model_id: string; + model_name: string; + model_type: string; } export interface ModelListInfo { - vendor: string - models: ModelInfo[] + vendor: string; + models: ModelInfo[]; +} + +export interface UsageInfo { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; } diff --git a/web/src/interfaces/index.ts b/web/src/interfaces/index.ts index 6c561ba..74c48cb 100644 --- a/web/src/interfaces/index.ts +++ b/web/src/interfaces/index.ts @@ -1,9 +1,9 @@ export interface ICommonResponse { - code: number - msg: string - data: T + code: number; + msg: string; + data: T; } -export type IMsgOnlyResponse = ICommonResponse<{ msg: string }> +export type IMsgOnlyResponse = ICommonResponse<{ msg: string }>; -export * from "./chat_service" +export * from "./chat_service"; diff --git a/web/src/layouts/BasicLayout.vue b/web/src/layouts/BasicLayout.vue index 1534a98..d2c6c0d 100644 --- a/web/src/layouts/BasicLayout.vue +++ b/web/src/layouts/BasicLayout.vue @@ -1,25 +1,49 @@