Files
Practical_Training_Assignment/backend/app/api/v1/endpoints/tts.py
2025-07-01 01:27:29 +08:00

497 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import uuid
import websockets
import time
import fastrand
import json
import asyncio
import os
import aiofiles
from datetime import datetime
from typing import Dict, Any, Optional as OptionalType
from app.constants.tts import APP_ID, TOKEN, SPEAKER
# --- Binary protocol constants ------------------------------------------------
# Header field values; Header.as_bytes packs them two per byte (first << 4 | second).
PROTOCOL_VERSION = 0x1
DEFAULT_HEADER_SIZE = 0x1

# Message types (packed into the high half of header byte 1).
FULL_CLIENT_REQUEST = 0x1
AUDIO_ONLY_RESPONSE = 0xB
FULL_SERVER_RESPONSE = 0x9
ERROR_INFORMATION = 0xF

# Message-type-specific flag: the frame carries a 4-byte event number right after
# the header (Optional.as_bytes writes it first; parser_response reads it first).
MsgTypeFlagWithEvent = 0x4

# Serialization methods (high half of header byte 2).
NO_SERIALIZATION = 0x0
JSON = 0x1

# Compression types (low half of header byte 2).
COMPRESSION_NO = 0x0

# --- Event numbers carried in the optional section ----------------------------
EVENT_NONE = 0
EVENT_Start_Connection = 1
EVENT_FinishConnection = 2
EVENT_ConnectionStarted = 50
EVENT_StartSession = 100
EVENT_FinishSession = 102
EVENT_SessionStarted = 150
EVENT_SessionFinished = 152
EVENT_TaskRequest = 200
EVENT_TTSSentenceEnd = 351
EVENT_TTSResponse = 352

# Directory (relative to the process working directory) where audio files are saved.
TEMP_AUDIO_DIR = "./temp_audio"
# Make sure the audio output directory exists.
async def ensure_audio_dir(path=None):
    """Create the audio output directory if it does not already exist.

    Args:
        path: Directory to create (parents included). ``None`` means
            ``TEMP_AUDIO_DIR``.

    Note:
        ``os.makedirs`` is a blocking call; it runs directly on the event loop
        even though this function is ``async``.
    """
    target = TEMP_AUDIO_DIR if path is None else path
    # exist_ok=True already tolerates an existing directory (including one created
    # concurrently by another task), so a separate exists() check is redundant.
    os.makedirs(target, exist_ok=True)
# Build a timestamp-based file name.
def generate_audio_filename() -> str:
    """Return a local-time, millisecond-resolution name such as ``20250701_012729_123.mp3``.

    Two calls within the same millisecond produce the same name.
    """
    now = datetime.now()
    millis = now.microsecond // 1000
    return now.strftime("%Y%m%d_%H%M%S_") + f"{millis:03d}.mp3"
# --- Protocol frame building blocks -------------------------------------------
class Header:
    """Four-byte frame header of the bidirectional TTS binary protocol.

    Bytes 0-2 each pack two fields as ``(first << 4) | second``; byte 3 is the
    reserved byte, emitted as-is. Field values are not masked.
    """

    def __init__(self,
                 protocol_version=PROTOCOL_VERSION,
                 header_size=DEFAULT_HEADER_SIZE,
                 message_type=0,
                 message_type_specific_flags=0,
                 serial_method=NO_SERIALIZATION,
                 compression_type=COMPRESSION_NO,
                 reserved_data=0):
        # Stored in wire order.
        self.protocol_version = protocol_version
        self.header_size = header_size
        self.message_type = message_type
        self.message_type_specific_flags = message_type_specific_flags
        self.serial_method = serial_method
        self.compression_type = compression_type
        self.reserved_data = reserved_data

    def as_bytes(self) -> bytes:
        """Serialize the header into its four wire bytes."""
        field_pairs = (
            (self.protocol_version, self.header_size),
            (self.message_type, self.message_type_specific_flags),
            (self.serial_method, self.compression_type),
        )
        packed = [(high << 4) | low for high, low in field_pairs]
        packed.append(self.reserved_data)
        return bytes(packed)
class Optional:
    """Optional section of a protocol frame (event / session id / sequence).

    ``as_bytes`` serializes ``event``, ``sessionId`` and ``sequence`` for client
    requests. ``errorCode``, ``connectionId`` and ``response_meta_json`` are only
    filled in when ``parser_response`` decodes a server frame.
    Note: this class shadows ``typing.Optional`` in this module.
    """

    def __init__(self, event=EVENT_NONE, sessionId=None, sequence=None):
        self.event = event
        self.sessionId = sessionId
        self.sequence = sequence
        # Response-only fields.
        self.errorCode = 0
        self.connectionId = None
        self.response_meta_json = None

    def as_bytes(self) -> bytes:
        """Serialize the present fields (returns a ``bytearray``).

        Layout: [event int32] [sessionId length int32 + UTF-8 bytes] [sequence int32],
        each part only when set; all integers are big-endian and signed.
        """
        buf = bytearray()

        def put_int32(value):
            buf.extend(value.to_bytes(4, "big", signed=True))

        if self.event != EVENT_NONE:
            put_int32(self.event)
        if self.sessionId is not None:
            encoded_id = str.encode(self.sessionId)
            put_int32(len(encoded_id))
            buf.extend(encoded_id)
        if self.sequence is not None:
            put_int32(self.sequence)
        return buf
class Response:
    """Decoded server frame: header, optional section and an event-dependent payload."""

    def __init__(self, header, optional):
        self.header = header
        self.optional = optional
        # Set by the parser depending on the frame/event type:
        self.payload = None       # raw bytes (audio chunk or error body)
        self.payload_json = None  # text read for sentence-end events
# --- Utility helpers ------------------------------------------------------------
def gen_log_id():
    """Build an ``X-Tt-Logid`` value.

    Format: ``"02"`` + millisecond Unix timestamp + a 32-character all-zero
    placeholder IP + a random number (offset by ``1 << 20``) as zero-padded hex.
    """
    millis = int(time.time() * 1000)
    rand_part = fastrand.pcg32bounded(1 << 24) + (1 << 20)
    placeholder_ip = "00000000000000000000000000000000"
    return "02" + str(millis) + placeholder_ip + format(rand_part, "08x")
def get_payload_bytes(uid='1234', event=EVENT_NONE, text='', speaker='', audio_format='mp3',
                      audio_sample_rate=24000):
    """Encode the JSON request body sent with StartSession / TaskRequest frames.

    Returns:
        UTF-8 bytes of the JSON document (``json.dumps`` defaults, so non-ASCII
        text is ``\\u``-escaped).
    """
    audio_params = {
        "format": audio_format,
        "sample_rate": audio_sample_rate,
        "enable_timestamp": True,
    }
    req_params = {"text": text, "speaker": speaker, "audio_params": audio_params}
    body = {
        "user": {"uid": uid},
        "event": event,
        "namespace": "BidirectionalTTS",
        "req_params": req_params,
    }
    return json.dumps(body).encode()
def read_res_content(res, offset):
    """Read a length-prefixed UTF-8 string from *res* at *offset*.

    Layout: 4-byte big-endian signed length, then that many bytes.
    No bounds checking: a short buffer yields a truncated string.

    Returns:
        ``(text, offset_after_text)``.
    """
    start = offset + 4
    length = int.from_bytes(res[offset:start], "big", signed=True)
    end = start + length
    return str(res[start:end], encoding='utf8'), end
def read_res_payload(res, offset):
    """Read a length-prefixed raw byte payload from *res* at *offset*.

    Layout: 4-byte big-endian signed length, then that many bytes.

    Returns:
        ``(payload_slice, offset_after_payload)``; the slice has the same type as *res*.
    """
    size_end = offset + 4
    size = int.from_bytes(res[offset:size_end], "big", signed=True)
    payload_end = size_end + size
    return res[size_end:payload_end], payload_end
def parser_response(res) -> Response:
    """Decode one frame received from the TTS service into a Response.

    Header: bytes 0-2 each carry two 4-bit fields (high nibble first); byte 3
    is reserved. The rest depends on the message type:

    * FULL_SERVER_RESPONSE / AUDIO_ONLY_RESPONSE with the event flag: a 4-byte
      event number followed by event-specific length-prefixed fields
      (connection id, session id, meta JSON, audio payload, ...).
    * ERROR_INFORMATION: a 4-byte error code followed by a length-prefixed
      payload.

    Args:
        res: Raw frame as received from the websocket (bytes-like).

    Returns:
        Response with ``header``/``optional`` populated and, depending on the
        event, ``payload`` (bytes) or ``payload_json`` (str).

    Raises:
        RuntimeError: If *res* is a text frame (``str``); its text becomes the message.
    """
    if isinstance(res, str):
        raise RuntimeError(res)
    response = Response(Header(), Optional())
    header = response.header
    optional = response.optional

    header.protocol_version = (res[0] >> 4) & 0x0f
    header.header_size = res[0] & 0x0f
    header.message_type = (res[1] >> 4) & 0x0f
    header.message_type_specific_flags = res[1] & 0x0f
    header.serial_method = (res[2] >> 4) & 0x0f
    header.compression_type = res[2] & 0x0f
    # Also exposed under the attribute name earlier versions of this parser used.
    header.message_compression = header.compression_type
    header.reserved_data = res[3]
    offset = 4

    # Membership test: `a == X or Y` would be always true and hide error frames.
    if header.message_type in (FULL_SERVER_RESPONSE, AUDIO_ONLY_RESPONSE):
        if header.message_type_specific_flags == MsgTypeFlagWithEvent:
            optional.event = int.from_bytes(res[offset:offset + 4], "big", signed=True)
            offset += 4
            event = optional.event
            if event == EVENT_NONE:
                return response
            elif event == EVENT_ConnectionStarted:
                optional.connectionId, offset = read_res_content(res, offset)
            elif event in (EVENT_SessionStarted, EVENT_SessionFinished):
                optional.sessionId, offset = read_res_content(res, offset)
                optional.response_meta_json, offset = read_res_content(res, offset)
            elif event == EVENT_TTSResponse:
                optional.sessionId, offset = read_res_content(res, offset)
                response.payload, offset = read_res_payload(res, offset)
            elif event == EVENT_TTSSentenceEnd:
                optional.sessionId, offset = read_res_content(res, offset)
                response.payload_json, offset = read_res_content(res, offset)
    elif header.message_type == ERROR_INFORMATION:
        optional.errorCode = int.from_bytes(res[offset:offset + 4], "big", signed=True)
        offset += 4
        response.payload, offset = read_res_payload(res, offset)
    return response
async def send_event(ws, header, optional=None, payload=None):
    """Assemble a client frame and send it over *ws* as a single binary message.

    Frame = header + optional (if given) + [int32 big-endian payload length +
    payload] (if given). The frame is sent as a ``bytearray``.
    """
    frame = bytearray(header)
    trailing = []
    if optional is not None:
        trailing.append(optional)
    if payload is not None:
        trailing.append(len(payload).to_bytes(4, 'big', signed=True))
        trailing.append(payload)
    for chunk in trailing:
        frame.extend(chunk)
    await ws.send(frame)
# Per-message TTS state, tracked by message ID together with its background task.
class TTSState:
    """Mutable bookkeeping for one TTS request (one frontend message).

    Attributes:
        message_id: Frontend message this state belongs to.
        volc_ws: Client connection to the upstream TTS service, or None.
        session_id: Upstream TTS session id, or None before a session is created.
        task: Background asyncio task doing the synthesis, or None.
        is_processing: True while the task is working on this request.
        audio_data: Audio bytes collected from the service so far.
        audio_filename: File name the audio will be saved under, or None.
    """

    def __init__(self, message_id: str):
        self.message_id = message_id
        # Holds the *client* connection returned by websockets.connect()
        # (see create_tts_connection), not a server-side protocol object.
        self.volc_ws: OptionalType["websockets.asyncio.client.ClientConnection"] = None
        self.session_id: OptionalType[str] = None
        self.task: OptionalType[asyncio.Task] = None
        self.is_processing = False
        self.audio_data = bytearray()
        self.audio_filename: OptionalType[str] = None
# Global state management.
class TTSManager:
    """Registry of in-flight TTS requests.

    Mappings:
        connections: frontend websocket -> {message_id: TTSState}
        session_to_message: upstream TTS session id -> message_id (for routing)
        message_to_websocket: message_id -> frontend websocket
    """

    def __init__(self):
        self.connections: Dict[Any, Dict[str, "TTSState"]] = {}
        self.session_to_message: Dict[str, str] = {}
        self.message_to_websocket: Dict[str, Any] = {}

    def get_connection_states(self, websocket) -> Dict[str, "TTSState"]:
        """Return the mutable state dict for *websocket*, creating it if absent."""
        return self.connections.setdefault(websocket, {})

    def add_tts_state(self, websocket, message_id: str) -> "TTSState":
        """Register a fresh TTSState for *message_id*, replacing (and cancelling) any old one."""
        states = self.get_connection_states(websocket)
        if message_id in states:
            self.cleanup_message_state(websocket, message_id)
        tts_state = TTSState(message_id)
        states[message_id] = tts_state
        self.message_to_websocket[message_id] = websocket
        return tts_state

    def get_tts_state(self, websocket, message_id: str) -> OptionalType["TTSState"]:
        """Look up a state; unknown websockets are NOT added to ``connections``."""
        return self.connections.get(websocket, {}).get(message_id)

    def register_session(self, session_id: str, message_id: str):
        """Map an upstream session id to the message it serves."""
        self.session_to_message[session_id] = message_id

    def get_message_by_session(self, session_id: str) -> OptionalType[str]:
        """Return the message id for *session_id*, or None."""
        return self.session_to_message.get(session_id)

    def get_websocket_by_message(self, message_id: str):
        """Return the frontend websocket for *message_id*, or None."""
        return self.message_to_websocket.get(message_id)

    def cleanup_message_state(self, websocket, message_id: str):
        """Cancel the task of one message (if still running) and drop all its mappings.

        A no-op for unknown websockets/messages. It never re-creates a
        ``connections`` entry, so late calls from finishing tasks after
        ``cleanup_connection`` do not leak the websocket.
        """
        states = self.connections.get(websocket)
        if not states or message_id not in states:
            return
        tts_state = states.pop(message_id)
        if tts_state.task and not tts_state.task.done():
            tts_state.task.cancel()
        if tts_state.session_id:
            self.session_to_message.pop(tts_state.session_id, None)
        # Only drop the reverse mapping if it still points at this websocket.
        if self.message_to_websocket.get(message_id) is websocket:
            del self.message_to_websocket[message_id]

    def cleanup_connection(self, websocket):
        """Clean up every message of *websocket* and forget the websocket itself."""
        states = self.connections.get(websocket)
        if states is None:
            return
        for message_id in list(states):
            self.cleanup_message_state(websocket, message_id)
        self.connections.pop(websocket, None)


# Module-wide TTS manager instance.
tts_manager = TTSManager()
# Open a dedicated connection to the TTS service.
async def create_tts_connection() -> "websockets.asyncio.client.ClientConnection":
    """Open a dedicated websocket to the bidirectional TTS endpoint and start a connection.

    Sends a StartConnection frame and waits for ConnectionStarted.

    Returns:
        The open websocket; the caller owns it and must close it.

    Raises:
        Exception: If the first reply is not ConnectionStarted. On this and any
            other failure (including cancellation) the socket is closed before
            the error propagates.
    """
    log_id = gen_log_id()
    ws_header = {
        "X-Api-App-Key": APP_ID,
        "X-Api-Access-Key": TOKEN,
        "X-Api-Resource-Id": 'volc.service_type.10029',
        "X-Api-Connect-Id": str(uuid.uuid4()),
        "X-Tt-Logid": log_id,
    }
    url = 'wss://openspeech.bytedance.com/api/v3/tts/bidirection'
    volc_ws = await websockets.connect(url, additional_headers=ws_header, max_size=1000000000)
    try:
        header = Header(message_type=FULL_CLIENT_REQUEST,
                        message_type_specific_flags=MsgTypeFlagWithEvent).as_bytes()
        optional = Optional(event=EVENT_Start_Connection).as_bytes()
        payload = str.encode("{}")
        await send_event(volc_ws, header, optional, payload)
        # Wait for the connection acknowledgement.
        raw_data = await volc_ws.recv()
        res = parser_response(raw_data)
        if res.optional.event != EVENT_ConnectionStarted:
            raise Exception("TTS连接失败")
    except BaseException:
        # Don't leak the socket when the handshake fails or the task is cancelled.
        try:
            await volc_ws.close()
        except Exception:
            pass  # best effort: the original error is the one worth propagating
        raise
    return volc_ws
# Write an audio file asynchronously.
async def save_audio_file(audio_data: bytes, filename: str) -> str:
    """Write *audio_data* to ``TEMP_AUDIO_DIR/filename`` (creating the directory) and return that path."""
    target = os.path.join(TEMP_AUDIO_DIR, filename)
    await ensure_audio_dir()
    async with aiofiles.open(target, 'wb') as out:
        await out.write(audio_data)
    return target
# Process a single TTS request end to end (private helpers first, entry point last).
async def _send_json_quietly(websocket, message: dict) -> None:
    """Best-effort send to the frontend; log instead of raising on failure.

    Used on error/cancel paths where the frontend socket may already be closed and
    an exception would only escape an unawaited background task.
    """
    try:
        await websocket.send_json(message)
    except Exception as send_err:
        print(f"发送消息到前端失败 [{message.get('messageId')}]: {send_err}")


def _raise_if_server_error(res) -> None:
    """Raise if *res* is an ERROR_INFORMATION frame from the TTS service."""
    if res.header.message_type != ERROR_INFORMATION:
        return
    detail = ""
    if res.payload:
        detail = bytes(res.payload).decode("utf-8", errors="replace")
    raise Exception(f"TTS服务返回错误 (code={res.optional.errorCode}): {detail}")


async def process_tts_task(websocket, message_id: str, text: str, speaker: OptionalType[str] = None):
    """Synthesize *text* over a dedicated upstream session and stream the audio to the frontend.

    Flow: open an upstream connection, StartSession, send the text as a
    TaskRequest, send FinishSession, then relay every TTSResponse audio chunk to
    *websocket* as a ``tts_audio_data`` message until SessionFinished arrives
    (or no frame arrives for 30 s). Collected audio is saved under
    TEMP_AUDIO_DIR and a ``tts_audio_complete`` message is sent.

    Failures are reported to the frontend as ``tts_error`` messages instead of
    being raised; cancellation is reported the same way and then re-raised.

    Args:
        websocket: Frontend connection (must provide ``send_json``).
        message_id: Key of the TTSState registered via ``tts_manager.add_tts_state``.
        text: Text to synthesize.
        speaker: Voice to use; falls back to SPEAKER when empty or None.
    """
    tts_state = None
    selected_speaker = speaker if speaker else SPEAKER
    try:
        print(f"开始处理TTS任务 [{message_id}]: {text}, 使用说话人: {selected_speaker}")
        tts_state = tts_manager.get_tts_state(websocket, message_id)
        if not tts_state:
            raise Exception(f"找不到TTS状态: {message_id}")
        tts_state.is_processing = True
        tts_state.audio_filename = generate_audio_filename()
        tts_state.volc_ws = await create_tts_connection()

        tts_state.session_id = uuid.uuid4().hex
        tts_manager.register_session(tts_state.session_id, message_id)
        print(f"创建TTS会话 [{message_id}]: {tts_state.session_id}")

        # Every client request in this session uses the same JSON-serialized header.
        request_header = Header(message_type=FULL_CLIENT_REQUEST,
                                message_type_specific_flags=MsgTypeFlagWithEvent,
                                serial_method=JSON).as_bytes()

        # 1) StartSession, then expect SessionStarted.
        optional = Optional(event=EVENT_StartSession, sessionId=tts_state.session_id).as_bytes()
        payload = get_payload_bytes(event=EVENT_StartSession, speaker=selected_speaker)
        await send_event(tts_state.volc_ws, request_header, optional, payload)
        res = parser_response(await tts_state.volc_ws.recv())
        _raise_if_server_error(res)
        if res.optional.event != EVENT_SessionStarted:
            raise Exception("TTS会话启动失败")
        print(f"TTS会话创建成功 [{message_id}]: {tts_state.session_id}")

        # 2) Send the whole text as one TaskRequest.
        print(f"发送文本到TTS服务 [{message_id}]...")
        optional = Optional(event=EVENT_TaskRequest, sessionId=tts_state.session_id).as_bytes()
        payload = get_payload_bytes(event=EVENT_TaskRequest, text=text, speaker=selected_speaker)
        await send_event(tts_state.volc_ws, request_header, optional, payload)

        # 3) FinishSession: no more text follows. NOTE(review): per the bidirectional
        # TTS protocol the service then synthesizes the remaining sentences and ends
        # with SessionFinished — confirm against the Volcengine docs.
        optional = Optional(event=EVENT_FinishSession, sessionId=tts_state.session_id).as_bytes()
        await send_event(tts_state.volc_ws, request_header, optional, b"{}")

        # 4) Relay audio until the session ends (or the service goes quiet).
        print(f"开始接收TTS响应 [{message_id}]...")
        audio_count = 0
        try:
            while True:
                raw_data = await asyncio.wait_for(tts_state.volc_ws.recv(), timeout=30)
                res = parser_response(raw_data)
                _raise_if_server_error(res)
                event = res.optional.event
                print(f"收到TTS事件 [{message_id}]: {event}")
                if event == EVENT_SessionFinished:
                    print(f"收到会话结束事件 [{message_id}]")
                    break
                if event == EVENT_TTSSentenceEnd:
                    # Only one sentence is done; multi-sentence text has more audio coming.
                    print(f"句子结束事件 [{message_id}]")
                elif event == EVENT_TTSResponse:
                    audio_count += 1
                    print(f"收到音频数据 [{message_id}] #{audio_count},大小: {len(res.payload)}")
                    tts_state.audio_data.extend(res.payload)
                    await websocket.send_json({
                        "id": audio_count,
                        "type": "tts_audio_data",
                        "messageId": message_id,
                        "audioData": res.payload.hex()  # hex string for JSON transport
                    })
                else:
                    print(f"未知TTS事件 [{message_id}]: {event}")
        except asyncio.TimeoutError:
            print(f"TTS响应超时 [{message_id}],强制结束")

        audio_path = None
        if tts_state.audio_data:
            audio_path = await save_audio_file(bytes(tts_state.audio_data), tts_state.audio_filename)
            print(f"音频文件已保存 [{message_id}]: {audio_path}")
        await websocket.send_json({
            "type": "tts_audio_complete",
            "messageId": message_id,
            "audioFile": tts_state.audio_filename,
            "audioPath": audio_path,
            "speaker": selected_speaker
        })
        print(f"TTS处理完成 [{message_id}],共发送 {audio_count} 个音频包,使用说话人: {selected_speaker}")
    except asyncio.CancelledError:
        print(f"TTS任务被取消 [{message_id}]")
        await _send_json_quietly(websocket, {
            "type": "tts_error",
            "messageId": message_id,
            "message": "TTS任务被取消"
        })
        # Re-raise so the task ends as cancelled rather than completing normally.
        raise
    except Exception as e:
        print(f"TTS处理异常 [{message_id}]: {e}")
        import traceback
        traceback.print_exc()
        await _send_json_quietly(websocket, {
            "type": "tts_error",
            "messageId": message_id,
            "message": f"TTS处理失败: {str(e)}"
        })
    finally:
        if tts_state:
            tts_state.is_processing = False
            if tts_state.volc_ws:
                try:
                    await tts_state.volc_ws.close()
                except Exception as close_err:
                    print(f"关闭TTS连接失败 [{message_id}]: {close_err}")
            # A newer request reusing this message_id may already have replaced our
            # state (add_tts_state cancels this task first); don't tear that one down.
            if tts_manager.get_tts_state(websocket, message_id) is tts_state:
                tts_manager.cleanup_message_state(websocket, message_id)
# Start a TTS conversion.
async def handle_tts_text(websocket, message_id: str, text: str, speaker: OptionalType[str] = None):
    """Register a fresh TTS state for *message_id* and start synthesis in the background.

    Returns immediately. The work runs in ``process_tts_task``; the task is
    stored on the state, which the manager keeps, so a strong reference exists
    while it runs. An existing state for the same id is replaced (and its task cancelled).
    """
    # Labeled log line instead of the former bare debug print of the speaker.
    print(f"收到TTS请求 [{message_id}],说话人: {speaker}")
    tts_state = tts_manager.add_tts_state(websocket, message_id)
    tts_state.task = asyncio.create_task(
        process_tts_task(websocket, message_id, text, speaker)
    )
# Cancel a TTS task.
async def handle_tts_cancel(websocket, message_id: str):
    """Cancel the running TTS task for *message_id* (if any) and report completion.

    A ``tts_complete`` message is always sent to *websocket*, after which the
    message's state is removed from the manager.
    """
    state = tts_manager.get_tts_state(websocket, message_id)
    running_task = state.task if state else None
    if running_task and not running_task.done():
        running_task.cancel()
    await websocket.send_json({"type": "tts_complete", "messageId": message_id})
    tts_manager.cleanup_message_state(websocket, message_id)
# Release every TTS resource held for a frontend connection.
async def cleanup_connection_tts(websocket):
    """Drop all TTS states of *websocket*, cancelling any task still running."""
    print("清理连接的TTS资源...")
    tts_manager.cleanup_connection(websocket)
    print("TTS资源清理完成")