# tts.py
"""Bridge between a client WebSocket and an upstream bidirectional TTS service.

The module contains:

* the binary frame format used on the upstream socket (``Header``,
  ``Optional``, ``Response``, ``parser_response``, ``send_event``);
* per-message bookkeeping (``TTSState``, ``TTSManager``);
* the coroutines that run / cancel one synthesis task per message id.
"""
import asyncio
import json
import time
import traceback
import uuid
from typing import Dict, Any, Optional as OptionalType

import fastrand
import websockets

from app.constants.tts import APP_ID, TOKEN, SPEAKER

# Protocol constants (4-bit fields packed into the 4-byte frame header).
PROTOCOL_VERSION = 0b0001
DEFAULT_HEADER_SIZE = 0b0001

# Message types
FULL_CLIENT_REQUEST = 0b0001
AUDIO_ONLY_RESPONSE = 0b1011
FULL_SERVER_RESPONSE = 0b1001
ERROR_INFORMATION = 0b1111

# Message-type-specific flag: the frame carries an event number.
MsgTypeFlagWithEvent = 0b100

# Serialization methods
NO_SERIALIZATION = 0b0000
JSON = 0b0001

# Compression
COMPRESSION_NO = 0b0000

# Event types
EVENT_NONE = 0
EVENT_Start_Connection = 1
EVENT_FinishConnection = 2
EVENT_ConnectionStarted = 50
EVENT_StartSession = 100
EVENT_FinishSession = 102
EVENT_SessionStarted = 150
EVENT_SessionFinished = 152
EVENT_TaskRequest = 200
EVENT_TTSSentenceEnd = 351
EVENT_TTSResponse = 352


class Header:
    """Fixed 4-byte frame header; every field except ``reserved_data`` is a nibble."""

    def __init__(self,
                 protocol_version=PROTOCOL_VERSION,
                 header_size=DEFAULT_HEADER_SIZE,
                 message_type=0,
                 message_type_specific_flags=0,
                 serial_method=NO_SERIALIZATION,
                 compression_type=COMPRESSION_NO,
                 reserved_data=0):
        self.header_size = header_size
        self.protocol_version = protocol_version
        self.message_type = message_type
        self.message_type_specific_flags = message_type_specific_flags
        self.serial_method = serial_method
        self.compression_type = compression_type
        self.reserved_data = reserved_data

    def as_bytes(self) -> bytes:
        """Pack the header into 4 bytes (high nibble first in each byte)."""
        return bytes([
            (self.protocol_version << 4) | self.header_size,
            (self.message_type << 4) | self.message_type_specific_flags,
            (self.serial_method << 4) | self.compression_type,
            self.reserved_data,
        ])


class Optional:
    """Optional frame section: event number, session id and sequence number.

    NOTE: this class shadows ``typing.Optional``, which is why the typing
    helper is imported as ``OptionalType`` in this module.
    """

    def __init__(self, event=EVENT_NONE, sessionId=None, sequence=None):
        self.event = event
        self.sessionId = sessionId
        self.errorCode = 0
        self.connectionId = None
        self.response_meta_json = None
        self.sequence = sequence

    def as_bytes(self) -> bytearray:
        """Serialize the fields that are set, each as big-endian signed int32.

        ``event`` is skipped when it is ``EVENT_NONE``; ``sessionId`` is written
        as a 4-byte length followed by its encoded bytes; ``sequence`` is
        skipped when it is ``None``.
        """
        option_bytes = bytearray()
        if self.event != EVENT_NONE:
            option_bytes.extend(self.event.to_bytes(4, "big", signed=True))
        if self.sessionId is not None:
            session_id_bytes = str.encode(self.sessionId)
            size = len(session_id_bytes).to_bytes(4, "big", signed=True)
            option_bytes.extend(size)
            option_bytes.extend(session_id_bytes)
        if self.sequence is not None:
            option_bytes.extend(self.sequence.to_bytes(4, "big", signed=True))
        return option_bytes


class Response:
    """A parsed server frame: header, optional section and payload."""

    def __init__(self, header, optional):
        self.optional = optional
        self.header = header
        self.payload = None       # raw bytes (audio data or error payload)
        self.payload_json = None  # text payload of a sentence-end event


# Utility functions

def gen_log_id():
    """Generate a log id: ``"02"`` + millisecond timestamp + 32 zeros + 8 hex digits."""
    ts = int(time.time() * 1000)
    r = fastrand.pcg32bounded(1 << 24) + (1 << 20)
    local_ip = "00000000000000000000000000000000"
    return f"02{ts}{local_ip}{r:08x}"


def get_payload_bytes(uid='1234', event=EVENT_NONE, text='', speaker='',
                      audio_format='mp3', audio_sample_rate=24000):
    """Build the JSON request payload and return it encoded as bytes."""
    return str.encode(json.dumps({
        "user": {"uid": uid},
        "event": event,
        "namespace": "BidirectionalTTS",
        "req_params": {
            "text": text,
            "speaker": speaker,
            "audio_params": {
                "format": audio_format,
                "sample_rate": audio_sample_rate,
                "enable_timestamp": True,
            }
        }
    }))


def read_res_content(res, offset):
    """Read a length-prefixed UTF-8 string at ``offset``.

    Returns ``(text, new_offset)`` where ``new_offset`` points just past the
    string.
    """
    content_size = int.from_bytes(res[offset: offset + 4], "big", signed=True)
    offset += 4
    content = str(res[offset: offset + content_size], encoding='utf8')
    offset += content_size
    return content, offset


def read_res_payload(res, offset):
    """Read a length-prefixed byte payload at ``offset``.

    Returns ``(payload_bytes, new_offset)``.
    """
    payload_size = int.from_bytes(res[offset: offset + 4], "big", signed=True)
    offset += 4
    payload = res[offset: offset + payload_size]
    offset += payload_size
    return payload, offset


def parser_response(res) -> Response:
    """Parse one binary frame received from the TTS service.

    Args:
        res: the raw frame. A ``str`` (text frame) is treated as an error.

    Returns:
        A ``Response`` whose header is always filled in. For server/audio
        responses carrying an event, the event-specific fields of
        ``response.optional`` and the payload are filled in. For
        ``ERROR_INFORMATION`` frames ``optional.errorCode`` and
        ``response.payload`` are filled in.

    Raises:
        RuntimeError: if ``res`` is a ``str``.
    """
    if isinstance(res, str):
        raise RuntimeError(res)

    response = Response(Header(), Optional())
    header = response.header
    nibble = 0b00001111

    header.protocol_version = (res[0] >> 4) & nibble
    header.header_size = res[0] & nibble
    header.message_type = (res[1] >> 4) & nibble
    header.message_type_specific_flags = res[1] & nibble
    # Fixed: the original shifted by the mask value (15) instead of 4.
    header.serial_method = (res[2] >> 4) & nibble
    # ``compression_type`` is the attribute ``Header`` declares; the old
    # ``message_compression`` name is kept as an alias for existing readers.
    header.compression_type = res[2] & nibble
    header.message_compression = header.compression_type
    header.reserved_data = res[3]

    offset = 4
    optional = response.optional

    # Fixed: ``x == A or B`` was always true, which made the error branch
    # below unreachable.
    if header.message_type in (FULL_SERVER_RESPONSE, AUDIO_ONLY_RESPONSE):
        if header.message_type_specific_flags == MsgTypeFlagWithEvent:
            optional.event = int.from_bytes(res[offset:offset + 4], "big", signed=True)
            offset += 4
            if optional.event == EVENT_NONE:
                return response
            elif optional.event == EVENT_ConnectionStarted:
                optional.connectionId, offset = read_res_content(res, offset)
            elif optional.event in (EVENT_SessionStarted, EVENT_SessionFinished):
                optional.sessionId, offset = read_res_content(res, offset)
                optional.response_meta_json, offset = read_res_content(res, offset)
            elif optional.event == EVENT_TTSResponse:
                optional.sessionId, offset = read_res_content(res, offset)
                response.payload, offset = read_res_payload(res, offset)
            elif optional.event == EVENT_TTSSentenceEnd:
                optional.sessionId, offset = read_res_content(res, offset)
                response.payload_json, offset = read_res_content(res, offset)
    elif header.message_type == ERROR_INFORMATION:
        optional.errorCode = int.from_bytes(res[offset:offset + 4], "big", signed=True)
        offset += 4
        response.payload, offset = read_res_payload(res, offset)
    return response


async def send_event(ws, header, optional=None, payload=None):
    """Assemble ``header + optional + len(payload) + payload`` and send it on ``ws``.

    ``optional`` and ``payload`` are skipped when ``None``; the payload length
    is written as a big-endian signed int32.
    """
    full_client_request = bytearray(header)
    if optional is not None:
        full_client_request.extend(optional)
    if payload is not None:
        payload_size = len(payload).to_bytes(4, 'big', signed=True)
        full_client_request.extend(payload_size)
        full_client_request.extend(payload)
    await ws.send(full_client_request)


def _current_task_or_none():
    """Return the running asyncio task, or ``None`` outside an event loop."""
    try:
        return asyncio.current_task()
    except RuntimeError:
        return None


async def _close_quietly(ws):
    """Close ``ws``; a failure to close is reported but never raised."""
    try:
        await ws.close()
    except Exception as exc:
        print(f"关闭TTS连接失败: {exc}")


async def _send_json_best_effort(websocket, message) -> bool:
    """Send ``message`` to the client; return ``False`` instead of raising on failure."""
    try:
        await websocket.send_json(message)
        return True
    except Exception as exc:
        print(f"发送消息失败 [{message.get('messageId')}]: {exc}")
        return False


# TTS state for one message: carries the message id and the task handle.
class TTSState:
    """State of one TTS request, identified by ``message_id``."""

    def __init__(self, message_id: str):
        self.message_id = message_id
        # Upstream connection returned by ``create_tts_connection``.
        self.volc_ws: OptionalType[Any] = None
        self.session_id: OptionalType[str] = None
        # The asyncio task running ``process_tts_task`` for this message.
        self.task: OptionalType[asyncio.Task] = None
        self.is_processing = False


# Global state management
class TTSManager:
    """Registry of TTS states, keyed by client websocket and message id."""

    def __init__(self):
        # websocket -> message id -> TTS state
        self.connections: Dict[Any, Dict[str, TTSState]] = {}
        # session id -> message id, used to route responses
        self.session_to_message: Dict[str, str] = {}
        # message id -> websocket
        self.message_to_websocket: Dict[str, Any] = {}

    def get_connection_states(self, websocket) -> Dict[str, TTSState]:
        """Return all TTS states of ``websocket``, creating the entry if missing."""
        return self.connections.setdefault(websocket, {})

    def add_tts_state(self, websocket, message_id: str) -> TTSState:
        """Register a fresh state for ``message_id``, replacing any existing one."""
        states = self.get_connection_states(websocket)
        if message_id in states:
            # Drop (and cancel) the previous request with the same id first.
            self.cleanup_message_state(websocket, message_id)
        tts_state = TTSState(message_id)
        states[message_id] = tts_state
        self.message_to_websocket[message_id] = websocket
        return tts_state

    def get_tts_state(self, websocket, message_id: str) -> OptionalType[TTSState]:
        """Return the state for ``message_id`` or ``None``.

        This is a pure lookup: unlike ``get_connection_states`` it does not
        create an entry for an unknown websocket.
        """
        return self.connections.get(websocket, {}).get(message_id)

    def register_session(self, session_id: str, message_id: str):
        """Record the session id -> message id mapping."""
        self.session_to_message[session_id] = message_id

    def get_message_by_session(self, session_id: str) -> OptionalType[str]:
        """Return the message id registered for ``session_id``, if any."""
        return self.session_to_message.get(session_id)

    def get_websocket_by_message(self, message_id: str):
        """Return the websocket registered for ``message_id``, if any."""
        return self.message_to_websocket.get(message_id)

    def cleanup_message_state(self, websocket, message_id: str):
        """Remove the state of ``message_id`` and cancel its task if still running.

        No-op when the websocket or the message id is unknown. It must not
        re-create a ``connections`` entry, because it is also called from task
        ``finally`` blocks after ``cleanup_connection`` has removed the
        websocket.
        """
        states = self.connections.get(websocket)
        if not states or message_id not in states:
            return

        tts_state = states.pop(message_id)

        # Cancel the task unless it is the one calling us (a task cleaning up
        # after itself in ``finally`` must not mark itself as cancelled).
        task = tts_state.task
        if task is not None and not task.done() and task is not _current_task_or_none():
            task.cancel()

        # Remove the routing entries that belong to this state.
        if tts_state.session_id:
            self.session_to_message.pop(tts_state.session_id, None)
        if self.message_to_websocket.get(message_id) is websocket:
            del self.message_to_websocket[message_id]

    def cleanup_connection(self, websocket):
        """Remove every state registered for ``websocket``."""
        states = self.connections.get(websocket)
        if states is None:
            return
        for message_id in list(states):
            self.cleanup_message_state(websocket, message_id)
        self.connections.pop(websocket, None)


# Global TTS manager instance
tts_manager = TTSManager()


async def create_tts_connection() -> Any:
    """Open a dedicated upstream TTS connection and perform the start handshake.

    Returns:
        The connected websocket object returned by ``websockets.connect``.
        NOTE(review): annotated as ``Any`` because the previous annotation
        named a server-side protocol class, while this is a client
        connection - confirm the concrete type for the installed
        ``websockets`` version.

    Raises:
        Exception: if the first frame received is not ``EVENT_ConnectionStarted``.
    """
    log_id = gen_log_id()
    ws_header = {
        "X-Api-App-Key": APP_ID,
        "X-Api-Access-Key": TOKEN,
        "X-Api-Resource-Id": 'volc.service_type.10029',
        "X-Api-Connect-Id": str(uuid.uuid4()),
        "X-Tt-Logid": log_id,
    }
    url = 'wss://openspeech.bytedance.com/api/v3/tts/bidirection'
    volc_ws = await websockets.connect(url, additional_headers=ws_header, max_size=1000000000)

    try:
        # Start the connection.
        header = Header(message_type=FULL_CLIENT_REQUEST,
                        message_type_specific_flags=MsgTypeFlagWithEvent).as_bytes()
        optional = Optional(event=EVENT_Start_Connection).as_bytes()
        payload = str.encode("{}")
        await send_event(volc_ws, header, optional, payload)

        # Wait for the connection acknowledgement.
        raw_data = await volc_ws.recv()
        res = parser_response(raw_data)
        if res.optional.event != EVENT_ConnectionStarted:
            raise Exception("TTS连接失败")
    except BaseException:
        # Do not leak the socket when the handshake fails or is cancelled.
        await _close_quietly(volc_ws)
        raise
    return volc_ws


async def process_tts_task(websocket, message_id: str, text: str, recv_timeout: float = 30):
    """Run one TTS request end to end (intended to run as its own task).

    Opens an upstream connection, starts a session, sends ``text`` and relays
    each audio chunk to ``websocket`` as a ``tts_audio_data`` JSON message
    (hex-encoded), followed by ``tts_audio_complete``. Failures are reported to
    the client as ``tts_error``.

    Args:
        websocket: client connection; must provide an awaitable ``send_json``.
        message_id: id of the message being synthesized; a state for it must
            already be registered in ``tts_manager``.
        text: text to synthesize.
        recv_timeout: seconds to wait for each upstream frame before the
            receive loop gives up.

    Raises:
        asyncio.CancelledError: re-raised after the client has been notified,
            so the task ends up in the cancelled state.
    """
    tts_state = None
    try:
        print(f"开始处理TTS任务 [{message_id}]: {text}")

        # Look up the TTS state.
        tts_state = tts_manager.get_tts_state(websocket, message_id)
        if not tts_state:
            raise Exception(f"找不到TTS状态: {message_id}")

        tts_state.is_processing = True

        # Create a dedicated upstream connection.
        tts_state.volc_ws = await create_tts_connection()

        # Start a session.
        tts_state.session_id = uuid.uuid4().hex
        tts_manager.register_session(tts_state.session_id, message_id)

        print(f"创建TTS会话 [{message_id}]: {tts_state.session_id}")

        header = Header(message_type=FULL_CLIENT_REQUEST,
                        message_type_specific_flags=MsgTypeFlagWithEvent,
                        serial_method=JSON).as_bytes()
        optional = Optional(event=EVENT_StartSession,
                            sessionId=tts_state.session_id).as_bytes()
        payload = get_payload_bytes(event=EVENT_StartSession, speaker=SPEAKER)
        await send_event(tts_state.volc_ws, header, optional, payload)

        raw_data = await tts_state.volc_ws.recv()
        res = parser_response(raw_data)
        if res.optional.event != EVENT_SessionStarted:
            raise Exception("TTS会话启动失败")

        print(f"TTS会话创建成功 [{message_id}]: {tts_state.session_id}")

        # Send the text to the TTS service.
        print(f"发送文本到TTS服务 [{message_id}]...")
        header = Header(message_type=FULL_CLIENT_REQUEST,
                        message_type_specific_flags=MsgTypeFlagWithEvent,
                        serial_method=JSON).as_bytes()
        optional = Optional(event=EVENT_TaskRequest,
                            sessionId=tts_state.session_id).as_bytes()
        payload = get_payload_bytes(event=EVENT_TaskRequest, text=text, speaker=SPEAKER)
        await send_event(tts_state.volc_ws, header, optional, payload)

        # Receive TTS responses and forward them to the client.
        print(f"开始接收TTS响应 [{message_id}]...")
        audio_count = 0

        try:
            while True:
                raw_data = await asyncio.wait_for(
                    tts_state.volc_ws.recv(),
                    timeout=recv_timeout
                )
                res = parser_response(raw_data)

                print(f"收到TTS事件 [{message_id}]: {res.optional.event}")

                # Surface server-side errors right away instead of waiting
                # for the receive timeout.
                if res.header.message_type == ERROR_INFORMATION:
                    detail = bytes(res.payload or b"").decode("utf-8", errors="replace")
                    raise RuntimeError(
                        f"TTS服务返回错误 {res.optional.errorCode}: {detail}"
                    )

                if res.optional.event == EVENT_TTSSentenceEnd:
                    # NOTE(review): the stream is treated as finished at the
                    # first sentence end - confirm this is intended for
                    # multi-sentence input.
                    print(f"句子结束事件 [{message_id}] - 直接完成")
                    break
                elif res.optional.event == EVENT_SessionFinished:
                    print(f"收到会话结束事件 [{message_id}]")
                    break
                elif res.optional.event == EVENT_TTSResponse:
                    audio_count += 1
                    print(f"发送音频数据 [{message_id}] #{audio_count},大小: {len(res.payload)}")

                    # Forward the audio chunk together with the message id.
                    await websocket.send_json({
                        "type": "tts_audio_data",
                        "messageId": message_id,
                        "audioData": res.payload.hex()  # hex-encoded string
                    })
                else:
                    print(f"未知TTS事件 [{message_id}]: {res.optional.event}")

        except asyncio.TimeoutError:
            print(f"TTS响应超时 [{message_id}],强制结束")

        # Tell the client the stream is complete.
        await websocket.send_json({
            "type": "tts_audio_complete",
            "messageId": message_id
        })

        print(f"TTS处理完成 [{message_id}],共发送 {audio_count} 个音频包")

    except asyncio.CancelledError:
        print(f"TTS任务被取消 [{message_id}]")
        # Best effort: the client socket may already be closed.
        await _send_json_best_effort(websocket, {
            "type": "tts_error",
            "messageId": message_id,
            "message": "TTS任务被取消"
        })
        # Cancellation must propagate so the task is marked as cancelled.
        raise
    except Exception as e:
        print(f"TTS处理异常 [{message_id}]: {e}")
        traceback.print_exc()
        await _send_json_best_effort(websocket, {
            "type": "tts_error",
            "messageId": message_id,
            "message": f"TTS处理失败: {str(e)}"
        })
    finally:
        # Release resources.
        if tts_state:
            tts_state.is_processing = False
            if tts_state.volc_ws:
                await _close_quietly(tts_state.volc_ws)

            # Only remove the registry entry if it is still ours: a newer
            # request may have re-registered the same message id, and its
            # state and task must not be torn down by this stale task.
            if tts_manager.get_tts_state(websocket, message_id) is tts_state:
                tts_manager.cleanup_message_state(websocket, message_id)


async def handle_tts_text(websocket, message_id: str, text: str):
    """Register a new TTS state for ``message_id`` and start its task."""
    tts_state = tts_manager.add_tts_state(websocket, message_id)

    # Run the synthesis in the background.
    tts_state.task = asyncio.create_task(
        process_tts_task(websocket, message_id, text)
    )


async def handle_tts_cancel(websocket, message_id: str):
    """Cancel the TTS task of ``message_id``, notify the client and drop its state."""
    tts_state = tts_manager.get_tts_state(websocket, message_id)
    try:
        if tts_state and tts_state.task and not tts_state.task.done():
            tts_state.task.cancel()
            await websocket.send_json({
                "type": "tts_complete",
                "messageId": message_id
            })
    finally:
        # Always drop the state, even if notifying the client failed.
        tts_manager.cleanup_message_state(websocket, message_id)


async def cleanup_connection_tts(websocket):
    """Release every TTS resource associated with ``websocket``."""
    print("清理连接的TTS资源...")
    tts_manager.cleanup_connection(websocket)
    print("TTS资源清理完成")