import uuid
import websockets
import time
import fastrand
import json
import asyncio
import os
import aiofiles
from datetime import datetime
from typing import Dict, Any, Optional as OptionalType

from app.constants.tts import APP_ID, TOKEN, SPEAKER

# Binary protocol constants (each value fits in a 4-bit header field).
PROTOCOL_VERSION = 0b0001
DEFAULT_HEADER_SIZE = 0b0001
FULL_CLIENT_REQUEST = 0b0001
AUDIO_ONLY_RESPONSE = 0b1011
FULL_SERVER_RESPONSE = 0b1001
ERROR_INFORMATION = 0b1111
MsgTypeFlagWithEvent = 0b100
NO_SERIALIZATION = 0b0000
JSON = 0b0001
COMPRESSION_NO = 0b0000

# Event codes carried in the optional section of a frame.
EVENT_NONE = 0
EVENT_Start_Connection = 1
EVENT_FinishConnection = 2
EVENT_ConnectionStarted = 50
EVENT_StartSession = 100
EVENT_FinishSession = 102
EVENT_SessionStarted = 150
EVENT_SessionFinished = 152
EVENT_TaskRequest = 200
EVENT_TTSSentenceEnd = 351
EVENT_TTSResponse = 352

# Directory where synthesized audio files are saved.
TEMP_AUDIO_DIR = "./temp_audio"


async def ensure_audio_dir():
    """Create TEMP_AUDIO_DIR if it does not exist yet.

    The function is ``async`` so callers can ``await`` it, but the directory
    creation itself is a plain synchronous ``os.makedirs`` call.
    """
    # exist_ok=True already tolerates an existing directory (including one
    # created concurrently), so a separate os.path.exists() check is redundant.
    os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)


def generate_audio_filename() -> str:
    """Return a new ``.mp3`` file name based on the current local time.

    The name is ``YYYYmmdd_HHMMSS_mmm_<8 hex chars>.mp3``. The timestamp is
    truncated to milliseconds. The random suffix keeps names unique when
    several TTS tasks start within the same millisecond, which would
    otherwise make them overwrite each other's files.
    """
    # %f gives microseconds (6 digits); dropping the last 3 keeps milliseconds.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
    suffix = uuid.uuid4().hex[:8]
    return f"{timestamp}_{suffix}.mp3"
class Header:
    """Fixed 4-byte header that prefixes every binary protocol frame.

    Byte layout (two 4-bit fields per byte, except the last byte):
    (protocol_version, header_size), (message_type, message_type_specific_flags),
    (serial_method, compression_type), reserved_data.
    """

    def __init__(self,
                 protocol_version=PROTOCOL_VERSION,
                 header_size=DEFAULT_HEADER_SIZE,
                 message_type=0,
                 message_type_specific_flags=0,
                 serial_method=NO_SERIALIZATION,
                 compression_type=COMPRESSION_NO,
                 reserved_data=0):
        self.header_size = header_size
        self.protocol_version = protocol_version
        self.message_type = message_type
        self.message_type_specific_flags = message_type_specific_flags
        self.serial_method = serial_method
        self.compression_type = compression_type
        self.reserved_data = reserved_data

    def as_bytes(self) -> bytes:
        """Pack the header fields into their 4-byte wire form."""
        return bytes([
            (self.protocol_version << 4) | self.header_size,
            (self.message_type << 4) | self.message_type_specific_flags,
            (self.serial_method << 4) | self.compression_type,
            self.reserved_data,
        ])


class Optional:
    """Optional frame section: event code, session id and sequence number.

    This class name shadows ``typing.Optional``, which is why the typing
    helper is imported as ``OptionalType``. ``errorCode``, ``connectionId``
    and ``response_meta_json`` are only filled in by ``parser_response``.
    """

    def __init__(self, event=EVENT_NONE, sessionId=None, sequence=None):
        self.event = event
        self.sessionId = sessionId
        self.errorCode = 0
        self.connectionId = None
        self.response_meta_json = None
        self.sequence = sequence

    def as_bytes(self) -> bytes:
        """Serialize the populated fields as big-endian signed 32-bit ints.

        Layout: [event][session-id length][session-id bytes][sequence].
        Each part is emitted only when it is set.
        """
        option_bytes = bytearray()
        if self.event != EVENT_NONE:
            option_bytes.extend(self.event.to_bytes(4, "big", signed=True))
        if self.sessionId is not None:
            session_id_bytes = str.encode(self.sessionId)
            size = len(session_id_bytes).to_bytes(4, "big", signed=True)
            option_bytes.extend(size)
            option_bytes.extend(session_id_bytes)
        if self.sequence is not None:
            option_bytes.extend(self.sequence.to_bytes(4, "big", signed=True))
        return option_bytes


class Response:
    """A decoded server frame: header, optional section and payload."""

    def __init__(self, header, optional):
        self.optional = optional
        self.header = header
        self.payload = None
        self.payload_json = None


# Utility functions
def gen_log_id():
    """Build a value for the ``X-Tt-Logid`` request header.

    Format: ``"02"`` + millisecond timestamp + 32 zero characters (a
    placeholder local IP) + a random number rendered as 8 hex digits.
    """
    ts = int(time.time() * 1000)
    r = fastrand.pcg32bounded(1 << 24) + (1 << 20)
    local_ip = "00000000000000000000000000000000"
    return f"02{ts}{local_ip}{r:08x}"


def get_payload_bytes(uid='1234', event=EVENT_NONE, text='', speaker='', audio_format='mp3',
                      audio_sample_rate=24000):
    """Encode a BidirectionalTTS JSON request payload as bytes."""
    return str.encode(json.dumps({
        "user": {"uid": uid},
        "event": event,
        "namespace": "BidirectionalTTS",
        "req_params": {
            "text": text,
            "speaker": speaker,
            "audio_params": {
                "format": audio_format,
                "sample_rate": audio_sample_rate,
                "enable_timestamp": True,
            }
        }
    }))


def read_res_content(res, offset):
    """Read a length-prefixed UTF-8 string at ``offset``.

    Returns the decoded string and the offset just past it.
    """
    content_size = int.from_bytes(res[offset: offset + 4], "big", signed=True)
    offset += 4
    content = str(res[offset: offset + content_size], encoding='utf8')
    offset += content_size
    return content, offset


def read_res_payload(res, offset):
    """Read a length-prefixed raw byte payload at ``offset``.

    Returns the payload bytes and the offset just past them.
    """
    payload_size = int.from_bytes(res[offset: offset + 4], "big", signed=True)
    offset += 4
    payload = res[offset: offset + payload_size]
    offset += payload_size
    return payload, offset


def parser_response(res) -> Response:
    """Decode one frame received from the TTS service.

    Args:
        res: The raw frame returned by ``recv()``.

    Returns:
        A Response whose header fields are always filled in.
        - FULL_SERVER_RESPONSE / AUDIO_ONLY_RESPONSE frames that carry an
          event also get the event-specific fields.
        - ERROR_INFORMATION frames get ``optional.errorCode`` and ``payload``.

    Raises:
        RuntimeError: if ``res`` is a text (``str``) frame.
    """
    if isinstance(res, str):
        raise RuntimeError(res)
    response = Response(Header(), Optional())
    header = response.header
    nibble = 0b00001111
    header.protocol_version = (res[0] >> 4) & nibble
    header.header_size = res[0] & nibble
    header.message_type = (res[1] >> 4) & nibble
    header.message_type_specific_flags = res[1] & nibble
    # High nibble of byte 2 (previously shifted by 15, which always gave 0).
    header.serial_method = (res[2] >> 4) & nibble
    header.compression_type = res[2] & nibble
    # Kept for any caller still reading the old attribute name.
    header.message_compression = header.compression_type
    header.reserved_data = res[3]
    offset = 4
    optional = response.optional
    # A bare `== A or B` is always truthy and made the error branch unreachable.
    if header.message_type in (FULL_SERVER_RESPONSE, AUDIO_ONLY_RESPONSE):
        if header.message_type_specific_flags == MsgTypeFlagWithEvent:
            optional.event = int.from_bytes(res[offset:offset + 4], "big", signed=True)
            offset += 4
            if optional.event == EVENT_NONE:
                return response
            elif optional.event == EVENT_ConnectionStarted:
                optional.connectionId, offset = read_res_content(res, offset)
            elif optional.event in (EVENT_SessionStarted, EVENT_SessionFinished):
                optional.sessionId, offset = read_res_content(res, offset)
                optional.response_meta_json, offset = read_res_content(res, offset)
            elif optional.event == EVENT_TTSResponse:
                optional.sessionId, offset = read_res_content(res, offset)
                response.payload, offset = read_res_payload(res, offset)
            elif optional.event == EVENT_TTSSentenceEnd:
                optional.sessionId, offset = read_res_content(res, offset)
                response.payload_json, offset = read_res_content(res, offset)
    elif header.message_type == ERROR_INFORMATION:
        optional.errorCode = int.from_bytes(res[offset:offset + 4], "big", signed=True)
        offset += 4
        response.payload, offset = read_res_payload(res, offset)
    return response


async def send_event(ws, header, optional=None, payload=None):
    """Send one binary frame: header, optional section, size-prefixed payload."""
    full_client_request = bytearray(header)
    if optional is not None:
        full_client_request.extend(optional)
    if payload is not None:
        payload_size = len(payload).to_bytes(4, 'big', signed=True)
        full_client_request.extend(payload_size)
        full_client_request.extend(payload)
    await ws.send(full_client_request)


class TTSState:
    """Per-message TTS job state: service socket, session, task and audio buffer."""

    def __init__(self, message_id: str):
        self.message_id = message_id
        self.volc_ws: OptionalType[websockets.WebSocketServerProtocol] = None
        self.session_id: OptionalType[str] = None
        self.task: OptionalType[asyncio.Task] = None  # the job's asyncio task
        self.is_processing = False
        self.audio_data = bytearray()  # audio chunks collected so far
        self.audio_filename = None  # file name the audio is saved under


def _current_task_or_none():
    """Return the running asyncio task, or None when no event loop is running."""
    try:
        return asyncio.current_task()
    except RuntimeError:
        return None


class TTSManager:
    """Registry of in-flight TTS jobs, grouped by client WebSocket."""

    def __init__(self):
        # websocket -> message id -> TTS state
        self.connections: Dict[Any, Dict[str, TTSState]] = {}
        # session id -> message id, used to route service responses
        self.session_to_message: Dict[str, str] = {}
        # message id -> websocket
        self.message_to_websocket: Dict[str, Any] = {}

    def get_connection_states(self, websocket) -> Dict[str, TTSState]:
        """Return the state dict for ``websocket``, registering it if new."""
        return self.connections.setdefault(websocket, {})

    def _peek_connection_states(self, websocket) -> Dict[str, TTSState]:
        """Return the state dict for ``websocket`` without registering it.

        Lookups and cleanup use this, so a task that finishes after its
        connection was torn down does not re-insert a never-removed entry
        for the dead websocket.
        """
        return self.connections.get(websocket, {})

    def add_tts_state(self, websocket, message_id: str) -> TTSState:
        """Create and register a fresh state, replacing any existing one."""
        states = self.get_connection_states(websocket)
        if message_id in states:
            self.cleanup_message_state(websocket, message_id)
        tts_state = TTSState(message_id)
        states[message_id] = tts_state
        self.message_to_websocket[message_id] = websocket
        return tts_state

    def get_tts_state(self, websocket, message_id: str) -> OptionalType[TTSState]:
        """Return the state for ``message_id`` on ``websocket``, or None."""
        return self._peek_connection_states(websocket).get(message_id)

    def register_session(self, session_id: str, message_id: str):
        """Map a TTS session id to the message id that owns it."""
        self.session_to_message[session_id] = message_id

    def get_message_by_session(self, session_id: str) -> OptionalType[str]:
        """Return the message id for ``session_id``, or None."""
        return self.session_to_message.get(session_id)

    def get_websocket_by_message(self, message_id: str):
        """Return the websocket that submitted ``message_id``, or None."""
        return self.message_to_websocket.get(message_id)

    def cleanup_message_state(self, websocket, message_id: str):
        """Cancel the message's task (if still running) and drop all its mappings."""
        tts_state = self._peek_connection_states(websocket).pop(message_id, None)
        if tts_state is None:
            return
        task = tts_state.task
        # A task cleaning up its own state must not cancel itself.
        if task and not task.done() and task is not _current_task_or_none():
            task.cancel()
        if tts_state.session_id:
            self.session_to_message.pop(tts_state.session_id, None)
        self.message_to_websocket.pop(message_id, None)

    def cleanup_connection(self, websocket):
        """Clean up every message of ``websocket`` and forget the connection."""
        states = self.connections.get(websocket)
        if states is None:
            return
        for message_id in list(states):
            self.cleanup_message_state(websocket, message_id)
        self.connections.pop(websocket, None)


# Global TTS manager instance.
tts_manager = TTSManager()

# Strong references to running TTS tasks. The event loop only keeps weak
# references, and a task's state entry can be removed (e.g. on cancel)
# while the task is still running its cleanup.
_background_tasks = set()


async def create_tts_connection() -> websockets.WebSocketServerProtocol:
    """Open a dedicated connection to the TTS service and perform StartConnection.

    Returns:
        The connected socket, after the service has answered ConnectionStarted.

    Raises:
        Exception: if the service does not answer with ConnectionStarted.
            The socket is closed before the error propagates.
    """
    log_id = gen_log_id()
    ws_header = {
        "X-Api-App-Key": APP_ID,
        "X-Api-Access-Key": TOKEN,
        "X-Api-Resource-Id": 'volc.service_type.10029',
        "X-Api-Connect-Id": str(uuid.uuid4()),
        "X-Tt-Logid": log_id,
    }
    url = 'wss://openspeech.bytedance.com/api/v3/tts/bidirection'
    volc_ws = await websockets.connect(url, additional_headers=ws_header, max_size=1000000000)
    try:
        header = Header(message_type=FULL_CLIENT_REQUEST,
                        message_type_specific_flags=MsgTypeFlagWithEvent).as_bytes()
        optional = Optional(event=EVENT_Start_Connection).as_bytes()
        payload = str.encode("{}")
        await send_event(volc_ws, header, optional, payload)
        raw_data = await volc_ws.recv()
        res = parser_response(raw_data)
        if res.optional.event != EVENT_ConnectionStarted:
            raise Exception("TTS连接失败")
    except BaseException:
        # Don't leak the socket when the handshake fails or is cancelled.
        try:
            await volc_ws.close()
        except Exception:
            pass  # best effort: the original error is what matters
        raise
    return volc_ws


async def save_audio_file(audio_data: bytes, filename: str) -> str:
    """Write ``audio_data`` to TEMP_AUDIO_DIR/``filename`` and return that path."""
    await ensure_audio_dir()
    file_path = os.path.join(TEMP_AUDIO_DIR, filename)
    async with aiofiles.open(file_path, 'wb') as f:
        await f.write(audio_data)
    return file_path


async def _safe_send_json(websocket, message_id: str, data: Dict[str, Any]) -> None:
    """Send a status message to the client on a best-effort basis.

    Used on error and cancellation paths, where the client socket may
    already be closed. A failed send is logged instead of raised.
    """
    try:
        await websocket.send_json(data)
    except Exception as send_error:
        print(f"发送消息到前端失败 [{message_id}]: {send_error}")


async def process_tts_task(websocket, message_id: str, text: str):
    """Synthesize ``text`` and stream the audio to ``websocket``.

    Runs as its own asyncio task (started by ``handle_tts_text``). Steps:
    1. Open a dedicated service connection and start a session.
    2. Send the text.
    3. Forward each audio chunk to the client as a ``tts_audio_data``
       message, with the chunk hex-encoded.
    4. Save the collected audio under TEMP_AUDIO_DIR and send
       ``tts_audio_complete``.

    Failures and cancellation are reported as ``tts_error``; cancellation is
    re-raised afterwards. The service socket is always closed, and this job's
    state is removed unless it has already been replaced.
    """
    tts_state = None
    try:
        print(f"开始处理TTS任务 [{message_id}]: {text}")
        tts_state = tts_manager.get_tts_state(websocket, message_id)
        if not tts_state:
            raise Exception(f"找不到TTS状态: {message_id}")
        tts_state.is_processing = True
        tts_state.audio_filename = generate_audio_filename()
        tts_state.volc_ws = await create_tts_connection()

        # Start a session (id = uuid4 as 32 hex chars without dashes).
        tts_state.session_id = uuid.uuid4().hex
        tts_manager.register_session(tts_state.session_id, message_id)
        print(f"创建TTS会话 [{message_id}]: {tts_state.session_id}")
        json_header = Header(message_type=FULL_CLIENT_REQUEST,
                             message_type_specific_flags=MsgTypeFlagWithEvent,
                             serial_method=JSON).as_bytes()
        optional = Optional(event=EVENT_StartSession, sessionId=tts_state.session_id).as_bytes()
        payload = get_payload_bytes(event=EVENT_StartSession, speaker=SPEAKER)
        await send_event(tts_state.volc_ws, json_header, optional, payload)
        res = parser_response(await tts_state.volc_ws.recv())
        if res.optional.event != EVENT_SessionStarted:
            raise Exception("TTS会话启动失败")
        print(f"TTS会话创建成功 [{message_id}]: {tts_state.session_id}")

        # Send the text to synthesize.
        print(f"发送文本到TTS服务 [{message_id}]...")
        optional = Optional(event=EVENT_TaskRequest, sessionId=tts_state.session_id).as_bytes()
        payload = get_payload_bytes(event=EVENT_TaskRequest, text=text, speaker=SPEAKER)
        await send_event(tts_state.volc_ws, json_header, optional, payload)

        # Receive audio and forward it to the client.
        print(f"开始接收TTS响应 [{message_id}]...")
        audio_count = 0
        try:
            while True:
                raw_data = await asyncio.wait_for(tts_state.volc_ws.recv(), timeout=30)
                res = parser_response(raw_data)
                print(f"收到TTS事件 [{message_id}]: {res.optional.event}")
                if res.header.message_type == ERROR_INFORMATION:
                    # Fail fast instead of idling until the receive timeout.
                    detail = (res.payload or b"").decode("utf-8", errors="replace")
                    raise Exception(f"TTS服务返回错误 [{res.optional.errorCode}]: {detail}")
                if res.optional.event == EVENT_TTSSentenceEnd:
                    # NOTE(review): stops at the first sentence end and no
                    # FinishSession is sent, so multi-sentence text may be
                    # truncated. Confirm this is intended.
                    print(f"句子结束事件 [{message_id}] - 直接完成")
                    break
                elif res.optional.event == EVENT_SessionFinished:
                    print(f"收到会话结束事件 [{message_id}]")
                    break
                elif res.optional.event == EVENT_TTSResponse:
                    audio_count += 1
                    print(f"收到音频数据 [{message_id}] #{audio_count},大小: {len(res.payload)}")
                    tts_state.audio_data.extend(res.payload)
                    await websocket.send_json({
                        "id": audio_count,
                        "type": "tts_audio_data",
                        "messageId": message_id,
                        "audioData": res.payload.hex(),
                    })
                else:
                    print(f"未知TTS事件 [{message_id}]: {res.optional.event}")
        except asyncio.TimeoutError:
            print(f"TTS响应超时 [{message_id}],强制结束")

        # Persist whatever audio arrived, then report completion.
        audio_path = None
        if tts_state.audio_data:
            audio_path = await save_audio_file(
                bytes(tts_state.audio_data), tts_state.audio_filename
            )
            print(f"音频文件已保存 [{message_id}]: {audio_path}")
        await websocket.send_json({
            "type": "tts_audio_complete",
            "messageId": message_id,
            "audioFile": tts_state.audio_filename,
            "audioPath": audio_path,
        })
        print(f"TTS处理完成 [{message_id}],共发送 {audio_count} 个音频包")
    except asyncio.CancelledError:
        print(f"TTS任务被取消 [{message_id}]")
        await _safe_send_json(websocket, message_id, {
            "type": "tts_error",
            "messageId": message_id,
            "message": "TTS任务被取消",
        })
        # Propagate so the task is correctly marked as cancelled.
        raise
    except Exception as e:
        import traceback
        print(f"TTS处理异常 [{message_id}]: {e}")
        traceback.print_exc()
        await _safe_send_json(websocket, message_id, {
            "type": "tts_error",
            "messageId": message_id,
            "message": f"TTS处理失败: {str(e)}",
        })
    finally:
        if tts_state:
            tts_state.is_processing = False
            try:
                if tts_state.volc_ws:
                    try:
                        await tts_state.volc_ws.close()
                    except Exception as close_error:
                        print(f"关闭TTS连接失败 [{message_id}]: {close_error}")
            finally:
                # The same message id may have been re-submitted meanwhile.
                # Only remove the entry if it is still this task's state.
                if tts_manager.get_tts_state(websocket, message_id) is tts_state:
                    tts_manager.cleanup_message_state(websocket, message_id)


async def handle_tts_text(websocket, message_id: str, text: str):
    """Register a new TTS job for ``message_id`` and start it in the background."""
    tts_state = tts_manager.add_tts_state(websocket, message_id)
    task = asyncio.create_task(process_tts_task(websocket, message_id, text))
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    tts_state.task = task


async def handle_tts_cancel(websocket, message_id: str):
    """Cancel the job for ``message_id``, acknowledge it and drop its state."""
    tts_state = tts_manager.get_tts_state(websocket, message_id)
    if tts_state and tts_state.task and not tts_state.task.done():
        tts_state.task.cancel()
    await websocket.send_json({
        "type": "tts_complete",
        "messageId": message_id,
    })
    tts_manager.cleanup_message_state(websocket, message_id)


async def cleanup_connection_tts(websocket):
    """Cancel and forget every TTS job of a client connection."""
    print(f"清理连接的TTS资源...")
    tts_manager.cleanup_connection(websocket)
    print("TTS资源清理完成")