456 lines
16 KiB
Python
456 lines
16 KiB
Python
# tts.py
|
||
import uuid
|
||
import websockets
|
||
import time
|
||
import fastrand
|
||
import json
|
||
import asyncio
|
||
from typing import Dict, Any, Optional as OptionalType
|
||
|
||
from app.constants.tts import APP_ID, TOKEN, SPEAKER
|
||
|
||
# Binary-protocol constants. Each value occupies one 4-bit field of the
# four-byte frame header assembled by Header.as_bytes().

# Byte 0: protocol version (high nibble) and header size (low nibble).
PROTOCOL_VERSION = 0b0001
DEFAULT_HEADER_SIZE = 0b0001

# Byte 1, high nibble: message type.
FULL_CLIENT_REQUEST = 0b0001
AUDIO_ONLY_RESPONSE = 0b1011
FULL_SERVER_RESPONSE = 0b1001
ERROR_INFORMATION = 0b1111

# Byte 1, low nibble: flag value that marks a frame carrying an event number.
MsgTypeFlagWithEvent = 0b0100

# Byte 2, high nibble: payload serialization method.
NO_SERIALIZATION = 0b0000
JSON = 0b0001

# Byte 2, low nibble: payload compression.
COMPRESSION_NO = 0b0000
|
||
|
||
# Event numbers carried in the optional section that follows the header.
EVENT_NONE = 0  # no event; Optional.as_bytes() omits the event field

# Connection lifecycle.
EVENT_Start_Connection = 1
EVENT_FinishConnection = 2
EVENT_ConnectionStarted = 50

# Session lifecycle.
EVENT_StartSession = 100
EVENT_FinishSession = 102
EVENT_SessionStarted = 150
EVENT_SessionFinished = 152

# Synthesis request and the events received while audio is streamed back.
EVENT_TaskRequest = 200
EVENT_TTSSentenceEnd = 351
EVENT_TTSResponse = 352
|
||
|
||
|
||
# 所有类定义保持不变...
|
||
class Header:
    """Four-byte frame header of the TTS binary protocol.

    All fields except ``reserved_data`` share a byte with a neighbour when
    serialized (see :meth:`as_bytes`). No range checking is done here.
    """

    def __init__(self,
                 protocol_version=PROTOCOL_VERSION,
                 header_size=DEFAULT_HEADER_SIZE,
                 message_type=0,
                 message_type_specific_flags=0,
                 serial_method=NO_SERIALIZATION,
                 compression_type=COMPRESSION_NO,
                 reserved_data=0):
        # Byte 0.
        self.header_size = header_size
        self.protocol_version = protocol_version
        # Byte 1.
        self.message_type = message_type
        self.message_type_specific_flags = message_type_specific_flags
        # Byte 2.
        self.serial_method = serial_method
        self.compression_type = compression_type
        # Byte 3.
        self.reserved_data = reserved_data

    def as_bytes(self) -> bytes:
        """Pack the header into four bytes, two fields per byte.

        In each of the first three bytes the first-named field is shifted
        into the high nibble and the second is OR-ed into the low nibble.
        """
        version_and_size = (self.protocol_version << 4) | self.header_size
        type_and_flags = (self.message_type << 4) | self.message_type_specific_flags
        serial_and_compression = (self.serial_method << 4) | self.compression_type
        return bytes((
            version_and_size,
            type_and_flags,
            serial_and_compression,
            self.reserved_data,
        ))
|
||
|
||
|
||
class Optional:
    """Optional fields that follow the frame header.

    The class name shadows ``typing.Optional``; the module imports the
    typing construct under the alias ``OptionalType`` for that reason.
    """

    def __init__(self, event=EVENT_NONE, sessionId=None, sequence=None):
        self.event = event
        self.sessionId = sessionId
        # The next three are assigned by parser_response(); as_bytes() does
        # not write them.
        self.errorCode = 0
        self.connectionId = None
        self.response_meta_json = None
        self.sequence = sequence

    def as_bytes(self) -> bytes:
        """Serialize the fields that are set, in a fixed order.

        Order: event number (skipped when it equals ``EVENT_NONE``), then
        the session ID preceded by its byte length, then the sequence
        number. Integers are written as 4-byte big-endian signed values.
        The returned object is a ``bytearray``.
        """
        buf = bytearray()

        if self.event != EVENT_NONE:
            buf += self.event.to_bytes(4, "big", signed=True)

        if self.sessionId is not None:
            encoded_session = str.encode(self.sessionId)
            buf += len(encoded_session).to_bytes(4, "big", signed=True)
            buf += encoded_session

        if self.sequence is not None:
            buf += self.sequence.to_bytes(4, "big", signed=True)

        return buf
|
||
|
||
|
||
class Response:
    """Container for one parsed frame: header, optional fields and payload."""

    def __init__(self, header, optional):
        self.header, self.optional = header, optional
        # Filled in by the parser when the frame carries them.
        self.payload = self.payload_json = None
|
||
|
||
|
||
# 工具函数保持不变...
|
||
def gen_log_id():
    """Build a log ID string.

    The result is the literal ``02``, the current Unix time in
    milliseconds, a fixed run of zero characters used in place of a local
    IP, and a random number rendered as zero-padded lowercase hex.
    """
    timestamp_ms = int(time.time() * 1000)
    # Random value plus a fixed 1 << 20 offset.
    random_part = (1 << 20) + fastrand.pcg32bounded(1 << 24)
    ip_placeholder = "00000000000000000000000000000000"
    return "02" + str(timestamp_ms) + ip_placeholder + format(random_part, "08x")
|
||
|
||
|
||
def get_payload_bytes(uid='1234', event=EVENT_NONE, text='', speaker='', audio_format='mp3',
                      audio_sample_rate=24000):
    """Return the JSON request payload for an event, encoded as bytes.

    The document has the keys ``user``, ``event``, ``namespace`` (always
    ``BidirectionalTTS``) and ``req_params``; timestamps are always
    requested via ``enable_timestamp``.
    """
    audio_params = {
        "format": audio_format,
        "sample_rate": audio_sample_rate,
        "enable_timestamp": True,
    }
    req_params = {
        "text": text,
        "speaker": speaker,
        "audio_params": audio_params,
    }
    document = {
        "user": {"uid": uid},
        "event": event,
        "namespace": "BidirectionalTTS",
        "req_params": req_params,
    }
    return json.dumps(document).encode()
|
||
|
||
|
||
def read_res_content(res, offset):
    """Read one length-prefixed UTF-8 string from ``res``.

    The four bytes at ``offset`` hold the length as a big-endian signed
    integer; the text follows immediately.

    Returns:
        ``(text, new_offset)`` where ``new_offset`` points just past the text.
    """
    body_start = offset + 4
    length = int.from_bytes(res[offset:body_start], "big", signed=True)
    body_end = body_start + length
    return str(res[body_start:body_end], "utf8"), body_end
|
||
|
||
|
||
def read_res_payload(res, offset):
    """Read one length-prefixed binary payload from ``res``.

    The four bytes at ``offset`` hold the length as a big-endian signed
    integer; the payload follows immediately and is returned undecoded.

    Returns:
        ``(payload, new_offset)`` where ``new_offset`` points just past it.
    """
    body_start = offset + 4
    length = int.from_bytes(res[offset:body_start], "big", signed=True)
    body_end = body_start + length
    return res[body_start:body_end], body_end
|
||
|
||
|
||
def parser_response(res) -> Response:
    """Parse one binary frame received from the TTS service.

    The first four bytes are unpacked into a ``Header``. For server
    response frames that carry an event, the event number and its
    event-specific fields are read into ``Optional`` / ``Response``. For
    error frames the error code and payload are read.

    Args:
        res: The received frame as a bytes-like object.

    Returns:
        The populated ``Response``.

    Raises:
        RuntimeError: If ``res`` is a ``str`` (a text frame); the text is
            used as the exception message.
    """
    if isinstance(res, str):
        raise RuntimeError(res)
    response = Response(Header(), Optional())

    # Header: four bytes, two 4-bit fields in each of the first three.
    header = response.header
    nibble = 0x0f
    header.protocol_version = (res[0] >> 4) & nibble
    header.header_size = res[0] & nibble
    header.message_type = (res[1] >> 4) & nibble
    header.message_type_specific_flags = res[1] & nibble
    # Shift by 4 to reach the high nibble. Shifting by the mask value (15),
    # as the code did before, always produced 0.
    header.serial_method = (res[2] >> 4) & nibble
    # Header defines ``compression_type``; ``message_compression`` is kept
    # as an alias because earlier versions of this parser set only that name.
    header.compression_type = res[2] & nibble
    header.message_compression = header.compression_type
    header.reserved_data = res[3]

    offset = 4
    optional = response.optional
    # Membership test: ``x == A or B`` is always true for a non-zero B and
    # made the error branch below unreachable.
    if header.message_type in (FULL_SERVER_RESPONSE, AUDIO_ONLY_RESPONSE):
        if header.message_type_specific_flags == MsgTypeFlagWithEvent:
            # Event number, then fields that depend on the event.
            optional.event = int.from_bytes(res[offset:offset + 4], "big", signed=True)
            offset += 4
            if optional.event == EVENT_NONE:
                return response
            elif optional.event == EVENT_ConnectionStarted:
                optional.connectionId, offset = read_res_content(res, offset)
            elif optional.event == EVENT_SessionStarted or optional.event == EVENT_SessionFinished:
                optional.sessionId, offset = read_res_content(res, offset)
                optional.response_meta_json, offset = read_res_content(res, offset)
            elif optional.event == EVENT_TTSResponse:
                # Audio chunk: payload stays raw bytes.
                optional.sessionId, offset = read_res_content(res, offset)
                response.payload, offset = read_res_payload(res, offset)
            elif optional.event == EVENT_TTSSentenceEnd:
                # Payload is decoded to text here, not parsed as JSON.
                optional.sessionId, offset = read_res_content(res, offset)
                response.payload_json, offset = read_res_content(res, offset)
    elif header.message_type == ERROR_INFORMATION:
        optional.errorCode = int.from_bytes(res[offset:offset + 4], "big", signed=True)
        offset += 4
        response.payload, offset = read_res_payload(res, offset)
    return response
|
||
|
||
|
||
async def send_event(ws, header, optional=None, payload=None):
    """Assemble one frame and send it over ``ws``.

    The frame is ``header``, then ``optional`` if given, then ``payload``
    (if given) preceded by its length as a 4-byte big-endian signed
    integer. The frame is passed to ``ws.send`` as a ``bytearray``.
    """
    frame = bytearray(header)

    if optional is not None:
        frame.extend(optional)

    if payload is not None:
        length_prefix = len(payload).to_bytes(4, 'big', signed=True)
        frame.extend(length_prefix)
        frame.extend(payload)

    await ws.send(frame)
|
||
|
||
|
||
# 修改:TTS状态管理类,添加消息ID和任务追踪
|
||
class TTSState:
    """Bookkeeping for one TTS request, identified by its message ID."""

    def __init__(self, message_id: str):
        self.message_id = message_id
        # Connection to the TTS service; assigned once the request starts.
        self.volc_ws: OptionalType[websockets.WebSocketServerProtocol] = None
        # Session ID generated for this request.
        self.session_id: OptionalType[str] = None
        # asyncio task running the request, kept so it can be cancelled.
        self.task: OptionalType[asyncio.Task] = None
        # True while the processing coroutine is running.
        self.is_processing = False
|
||
|
||
|
||
# 全局状态管理
|
||
class TTSManager:
    """Registry of TTS requests, grouped by client WebSocket.

    Three indexes are maintained together:

    * ``connections``: websocket -> {message ID -> ``TTSState``}
    * ``session_to_message``: session ID -> message ID (for routing)
    * ``message_to_websocket``: message ID -> websocket
    """

    def __init__(self):
        self.connections: Dict[Any, Dict[str, "TTSState"]] = {}
        self.session_to_message: Dict[str, str] = {}
        self.message_to_websocket: Dict[str, Any] = {}

    def get_connection_states(self, websocket) -> "Dict[str, TTSState]":
        """Return the state mapping for ``websocket``, creating it if absent."""
        return self.connections.setdefault(websocket, {})

    def add_tts_state(self, websocket, message_id: str) -> "TTSState":
        """Register a fresh state for ``message_id`` and return it.

        An existing state with the same ID on this websocket is cleaned up
        first, which cancels its task if that is still running.
        """
        states = self.get_connection_states(websocket)
        if message_id in states:
            self.cleanup_message_state(websocket, message_id)

        tts_state = TTSState(message_id)
        states[message_id] = tts_state
        self.message_to_websocket[message_id] = websocket
        return tts_state

    def get_tts_state(self, websocket, message_id: str) -> "OptionalType[TTSState]":
        """Return the state for ``message_id`` on ``websocket``, or ``None``."""
        # Plain lookup: must not register a websocket it has never seen,
        # otherwise a late query re-inserts a closed connection for good.
        return self.connections.get(websocket, {}).get(message_id)

    def register_session(self, session_id: str, message_id: str):
        """Record which message a session ID belongs to."""
        self.session_to_message[session_id] = message_id

    def get_message_by_session(self, session_id: str) -> OptionalType[str]:
        """Return the message ID registered for ``session_id``, or ``None``."""
        return self.session_to_message.get(session_id)

    def get_websocket_by_message(self, message_id: str):
        """Return the websocket registered for ``message_id``, or ``None``."""
        return self.message_to_websocket.get(message_id)

    def cleanup_message_state(self, websocket, message_id: str):
        """Remove one request's state, cancelling its task if still running.

        Does nothing when the websocket or the message ID is unknown.
        """
        # No auto-creation here: tasks run this from ``finally`` after
        # cleanup_connection() may already have dropped the websocket.
        states = self.connections.get(websocket)
        if not states or message_id not in states:
            return

        tts_state = states.pop(message_id)
        if tts_state.task and not tts_state.task.done():
            tts_state.task.cancel()
        if tts_state.session_id:
            self.session_to_message.pop(tts_state.session_id, None)
        # Only drop the routing entry if it belongs to this websocket; the
        # index is keyed by message ID alone and may point elsewhere.
        if self.message_to_websocket.get(message_id) is websocket:
            del self.message_to_websocket[message_id]

    def cleanup_connection(self, websocket):
        """Remove every request state registered for ``websocket``."""
        states = self.connections.get(websocket)
        if states is None:
            return
        # Iterate over a snapshot: cleanup_message_state mutates ``states``.
        for message_id in list(states):
            self.cleanup_message_state(websocket, message_id)
        self.connections.pop(websocket, None)
|
||
|
||
|
||
# 全局TTS管理器实例
|
||
# Module-level TTSManager instance used by the handler coroutines in this file.
tts_manager = TTSManager()
|
||
|
||
|
||
# 初始化独立的TTS连接
|
||
async def create_tts_connection() -> websockets.WebSocketServerProtocol:
    """Open a dedicated connection to the TTS service and start it.

    Connects with the credential headers, sends the start-connection
    event and waits for the service to confirm it.

    Returns:
        The open WebSocket; the caller is responsible for closing it.
        NOTE(review): the return annotation names a server-side protocol
        class although this is a client connection - confirm the intended
        type.

    Raises:
        Exception: If the first reply is not a connection-started event.
            Errors from the WebSocket library and the parser propagate
            unchanged. In every failure case the socket is closed first.
    """
    log_id = gen_log_id()
    ws_header = {
        "X-Api-App-Key": APP_ID,
        "X-Api-Access-Key": TOKEN,
        "X-Api-Resource-Id": 'volc.service_type.10029',
        "X-Api-Connect-Id": str(uuid.uuid4()),
        "X-Tt-Logid": log_id,
    }
    url = 'wss://openspeech.bytedance.com/api/v3/tts/bidirection'
    volc_ws = await websockets.connect(url, additional_headers=ws_header, max_size=1000000000)

    try:
        # Start the connection at protocol level.
        header = Header(message_type=FULL_CLIENT_REQUEST,
                        message_type_specific_flags=MsgTypeFlagWithEvent).as_bytes()
        optional = Optional(event=EVENT_Start_Connection).as_bytes()
        payload = str.encode("{}")
        await send_event(volc_ws, header, optional, payload)

        # Wait for the confirmation.
        raw_data = await volc_ws.recv()
        res = parser_response(raw_data)
        if res.optional.event != EVENT_ConnectionStarted:
            raise Exception("TTS连接失败")
    except BaseException:
        # The socket has not been handed to the caller yet, so nobody else
        # can close it. BaseException so that cancellation is covered too.
        await volc_ws.close()
        raise

    return volc_ws
|
||
|
||
|
||
# 处理单个TTS任务
|
||
async def _send_tts_error(websocket, message_id: str, message: str):
    """Send a ``tts_error`` message to the client, logging any send failure."""
    try:
        await websocket.send_json({
            "type": "tts_error",
            "messageId": message_id,
            "message": message
        })
    except Exception as notify_error:
        # The client may already be gone; the original error matters more.
        print(f"发送TTS错误消息失败 [{message_id}]: {notify_error}")


async def process_tts_task(websocket, message_id: str, text: str):
    """Run one TTS request end to end as an independent coroutine.

    Steps: look up the registered state, open a dedicated service
    connection, start a session, send ``text``, then forward each audio
    chunk to the client as a hex string in a ``tts_audio_data`` message.
    A ``tts_audio_complete`` message is sent when streaming stops.

    Args:
        websocket: Client connection; must provide ``send_json``.
        message_id: ID under which the state was registered in ``tts_manager``.
        text: Text to synthesize.

    Raises:
        asyncio.CancelledError: Re-raised after the client has been told
            about the cancellation and resources have been released.
        Other exceptions are reported to the client as ``tts_error`` and
        not propagated.
    """
    tts_state = None
    try:
        print(f"开始处理TTS任务 [{message_id}]: {text}")

        tts_state = tts_manager.get_tts_state(websocket, message_id)
        if not tts_state:
            raise Exception(f"找不到TTS状态: {message_id}")

        tts_state.is_processing = True

        # Each request gets its own connection to the service.
        tts_state.volc_ws = await create_tts_connection()

        # Start a session.
        tts_state.session_id = str(uuid.uuid4()).replace('-', '')
        tts_manager.register_session(tts_state.session_id, message_id)

        print(f"创建TTS会话 [{message_id}]: {tts_state.session_id}")
        header = Header(message_type=FULL_CLIENT_REQUEST,
                        message_type_specific_flags=MsgTypeFlagWithEvent,
                        serial_method=JSON).as_bytes()
        optional = Optional(event=EVENT_StartSession, sessionId=tts_state.session_id).as_bytes()
        payload = get_payload_bytes(event=EVENT_StartSession, speaker=SPEAKER)
        await send_event(tts_state.volc_ws, header, optional, payload)

        raw_data = await tts_state.volc_ws.recv()
        res = parser_response(raw_data)
        if res.optional.event != EVENT_SessionStarted:
            raise Exception("TTS会话启动失败")
        print(f"TTS会话创建成功 [{message_id}]: {tts_state.session_id}")

        # Send the text.
        print(f"发送文本到TTS服务 [{message_id}]...")
        header = Header(message_type=FULL_CLIENT_REQUEST,
                        message_type_specific_flags=MsgTypeFlagWithEvent,
                        serial_method=JSON).as_bytes()
        optional = Optional(event=EVENT_TaskRequest, sessionId=tts_state.session_id).as_bytes()
        payload = get_payload_bytes(event=EVENT_TaskRequest, text=text, speaker=SPEAKER)
        await send_event(tts_state.volc_ws, header, optional, payload)

        # Forward audio to the client until an end event or a timeout.
        print(f"开始接收TTS响应 [{message_id}]...")
        audio_count = 0

        try:
            while True:
                raw_data = await asyncio.wait_for(
                    tts_state.volc_ws.recv(),
                    timeout=30
                )
                res = parser_response(raw_data)

                print(f"收到TTS事件 [{message_id}]: {res.optional.event}")

                if res.optional.event == EVENT_TTSSentenceEnd:
                    # NOTE(review): streaming stops at the first sentence-end
                    # event. If the service emits one per sentence, text with
                    # several sentences would be cut short - confirm.
                    print(f"句子结束事件 [{message_id}] - 直接完成")
                    break

                elif res.optional.event == EVENT_SessionFinished:
                    print(f"收到会话结束事件 [{message_id}]")
                    break

                elif res.optional.event == EVENT_TTSResponse:
                    audio_count += 1
                    print(f"发送音频数据 [{message_id}] #{audio_count},大小: {len(res.payload)}")
                    await websocket.send_json({
                        "type": "tts_audio_data",
                        "messageId": message_id,
                        "audioData": res.payload.hex()  # hex so it fits in JSON
                    })
                else:
                    print(f"未知TTS事件 [{message_id}]: {res.optional.event}")

        except asyncio.TimeoutError:
            # No frame for 30 seconds: treat whatever was sent as complete.
            print(f"TTS响应超时 [{message_id}],强制结束")

        await websocket.send_json({
            "type": "tts_audio_complete",
            "messageId": message_id
        })
        print(f"TTS处理完成 [{message_id}],共发送 {audio_count} 个音频包")

    except asyncio.CancelledError:
        print(f"TTS任务被取消 [{message_id}]")
        await _send_tts_error(websocket, message_id, "TTS任务被取消")
        # Propagate so the task ends up in the cancelled state.
        raise
    except Exception as e:
        print(f"TTS处理异常 [{message_id}]: {e}")
        import traceback
        traceback.print_exc()
        await _send_tts_error(websocket, message_id, f"TTS处理失败: {str(e)}")
    finally:
        if tts_state:
            tts_state.is_processing = False
            if tts_state.volc_ws:
                try:
                    await tts_state.volc_ws.close()
                except Exception as close_error:
                    # Closing is best effort, but the failure is not hidden.
                    print(f"关闭TTS连接失败 [{message_id}]: {close_error}")
            # Unregister only if the registry still holds *this* request.
            # After a restart with the same message ID the entry belongs to
            # the replacement, and cleaning up by ID would destroy it and
            # cancel its task.
            if tts_manager.get_tts_state(websocket, message_id) is tts_state:
                # Detach the task first: cleanup cancels a task that is not
                # done, and here that would be the task running this code.
                tts_state.task = None
                tts_manager.cleanup_message_state(websocket, message_id)
|
||
|
||
|
||
# 启动TTS文本转换
|
||
async def handle_tts_text(websocket, message_id: str, text: str):
    """Register a TTS request and start processing it in the background.

    The request state is created in ``tts_manager`` and the processing
    coroutine is scheduled as an asyncio task, whose handle is stored on
    the state. This function does not wait for the task.
    """
    state = tts_manager.add_tts_state(websocket, message_id)
    worker = process_tts_task(websocket, message_id, text)
    state.task = asyncio.create_task(worker)
|
||
|
||
|
||
# 取消TTS任务
|
||
async def handle_tts_cancel(websocket, message_id: str):
    """Cancel the TTS request registered under ``message_id``.

    If the request has a task that is still running, the task is cancelled
    and a ``tts_complete`` message is sent to the client. The request
    state is then removed from ``tts_manager``.
    """
    # NOTE(review): statement nesting was reconstructed from a source
    # without indentation - confirm that the send belongs inside the
    # condition and the cleanup outside it.
    state = tts_manager.get_tts_state(websocket, message_id)
    still_running = bool(state and state.task and not state.task.done())
    if still_running:
        state.task.cancel()
        notice = {
            "type": "tts_complete",
            "messageId": message_id
        }
        await websocket.send_json(notice)
    tts_manager.cleanup_message_state(websocket, message_id)
|
||
|
||
|
||
# 清理连接的所有TTS资源
|
||
async def cleanup_connection_tts(websocket):
    """Drop every TTS request state that ``tts_manager`` holds for ``websocket``.

    Logs a line before and after the cleanup.
    """
    print("清理连接的TTS资源...")
    tts_manager.cleanup_connection(websocket)
    print("TTS资源清理完成")
|