feat: tts语音生成
This commit is contained in:
@@ -0,0 +1,455 @@
|
|||||||
|
# tts.py
|
||||||
|
import uuid
|
||||||
|
import websockets
|
||||||
|
import time
|
||||||
|
import fastrand
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
from typing import Dict, Any, Optional as OptionalType
|
||||||
|
|
||||||
|
from app.constants.tts import APP_ID, TOKEN, SPEAKER
|
||||||
|
|
||||||
|
# 协议常量保持不变...
|
||||||
|
PROTOCOL_VERSION = 0b0001
|
||||||
|
DEFAULT_HEADER_SIZE = 0b0001
|
||||||
|
FULL_CLIENT_REQUEST = 0b0001
|
||||||
|
AUDIO_ONLY_RESPONSE = 0b1011
|
||||||
|
FULL_SERVER_RESPONSE = 0b1001
|
||||||
|
ERROR_INFORMATION = 0b1111
|
||||||
|
MsgTypeFlagWithEvent = 0b100
|
||||||
|
NO_SERIALIZATION = 0b0000
|
||||||
|
JSON = 0b0001
|
||||||
|
COMPRESSION_NO = 0b0000
|
||||||
|
|
||||||
|
# 事件类型
|
||||||
|
EVENT_NONE = 0
|
||||||
|
EVENT_Start_Connection = 1
|
||||||
|
EVENT_FinishConnection = 2
|
||||||
|
EVENT_ConnectionStarted = 50
|
||||||
|
EVENT_StartSession = 100
|
||||||
|
EVENT_FinishSession = 102
|
||||||
|
EVENT_SessionStarted = 150
|
||||||
|
EVENT_SessionFinished = 152
|
||||||
|
EVENT_TaskRequest = 200
|
||||||
|
EVENT_TTSSentenceEnd = 351
|
||||||
|
EVENT_TTSResponse = 352
|
||||||
|
|
||||||
|
|
||||||
|
# 所有类定义保持不变...
|
||||||
|
class Header:
|
||||||
|
def __init__(self,
|
||||||
|
protocol_version=PROTOCOL_VERSION,
|
||||||
|
header_size=DEFAULT_HEADER_SIZE,
|
||||||
|
message_type=0,
|
||||||
|
message_type_specific_flags=0,
|
||||||
|
serial_method=NO_SERIALIZATION,
|
||||||
|
compression_type=COMPRESSION_NO,
|
||||||
|
reserved_data=0):
|
||||||
|
self.header_size = header_size
|
||||||
|
self.protocol_version = protocol_version
|
||||||
|
self.message_type = message_type
|
||||||
|
self.message_type_specific_flags = message_type_specific_flags
|
||||||
|
self.serial_method = serial_method
|
||||||
|
self.compression_type = compression_type
|
||||||
|
self.reserved_data = reserved_data
|
||||||
|
|
||||||
|
def as_bytes(self) -> bytes:
|
||||||
|
return bytes([
|
||||||
|
(self.protocol_version << 4) | self.header_size,
|
||||||
|
(self.message_type << 4) | self.message_type_specific_flags,
|
||||||
|
(self.serial_method << 4) | self.compression_type,
|
||||||
|
self.reserved_data
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
class Optional:
|
||||||
|
def __init__(self, event=EVENT_NONE, sessionId=None, sequence=None):
|
||||||
|
self.event = event
|
||||||
|
self.sessionId = sessionId
|
||||||
|
self.errorCode = 0
|
||||||
|
self.connectionId = None
|
||||||
|
self.response_meta_json = None
|
||||||
|
self.sequence = sequence
|
||||||
|
|
||||||
|
def as_bytes(self) -> bytes:
|
||||||
|
option_bytes = bytearray()
|
||||||
|
if self.event != EVENT_NONE:
|
||||||
|
option_bytes.extend(self.event.to_bytes(4, "big", signed=True))
|
||||||
|
if self.sessionId is not None:
|
||||||
|
session_id_bytes = str.encode(self.sessionId)
|
||||||
|
size = len(session_id_bytes).to_bytes(4, "big", signed=True)
|
||||||
|
option_bytes.extend(size)
|
||||||
|
option_bytes.extend(session_id_bytes)
|
||||||
|
if self.sequence is not None:
|
||||||
|
option_bytes.extend(self.sequence.to_bytes(4, "big", signed=True))
|
||||||
|
return option_bytes
|
||||||
|
|
||||||
|
|
||||||
|
class Response:
|
||||||
|
def __init__(self, header, optional):
|
||||||
|
self.optional = optional
|
||||||
|
self.header = header
|
||||||
|
self.payload = None
|
||||||
|
self.payload_json = None
|
||||||
|
|
||||||
|
|
||||||
|
# 工具函数保持不变...
|
||||||
|
def gen_log_id():
|
||||||
|
"""生成logID"""
|
||||||
|
ts = int(time.time() * 1000)
|
||||||
|
r = fastrand.pcg32bounded(1 << 24) + (1 << 20)
|
||||||
|
local_ip = "00000000000000000000000000000000"
|
||||||
|
return f"02{ts}{local_ip}{r:08x}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_payload_bytes(uid='1234', event=EVENT_NONE, text='', speaker='', audio_format='mp3',
|
||||||
|
audio_sample_rate=24000):
|
||||||
|
return str.encode(json.dumps({
|
||||||
|
"user": {"uid": uid},
|
||||||
|
"event": event,
|
||||||
|
"namespace": "BidirectionalTTS",
|
||||||
|
"req_params": {
|
||||||
|
"text": text,
|
||||||
|
"speaker": speaker,
|
||||||
|
"audio_params": {
|
||||||
|
"format": audio_format,
|
||||||
|
"sample_rate": audio_sample_rate,
|
||||||
|
"enable_timestamp": True,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
|
||||||
|
def read_res_content(res, offset):
|
||||||
|
content_size = int.from_bytes(res[offset: offset + 4], "big", signed=True)
|
||||||
|
offset += 4
|
||||||
|
content = str(res[offset: offset + content_size], encoding='utf8')
|
||||||
|
offset += content_size
|
||||||
|
return content, offset
|
||||||
|
|
||||||
|
|
||||||
|
def read_res_payload(res, offset):
|
||||||
|
payload_size = int.from_bytes(res[offset: offset + 4], "big", signed=True)
|
||||||
|
offset += 4
|
||||||
|
payload = res[offset: offset + payload_size]
|
||||||
|
offset += payload_size
|
||||||
|
return payload, offset
|
||||||
|
|
||||||
|
|
||||||
|
def parser_response(res) -> Response:
|
||||||
|
if isinstance(res, str):
|
||||||
|
raise RuntimeError(res)
|
||||||
|
response = Response(Header(), Optional())
|
||||||
|
# 解析结果
|
||||||
|
header = response.header
|
||||||
|
num = 0b00001111
|
||||||
|
header.protocol_version = res[0] >> 4 & num
|
||||||
|
header.header_size = res[0] & 0x0f
|
||||||
|
header.message_type = (res[1] >> 4) & num
|
||||||
|
header.message_type_specific_flags = res[1] & 0x0f
|
||||||
|
header.serial_method = res[2] >> num
|
||||||
|
header.message_compression = res[2] & 0x0f
|
||||||
|
header.reserved_data = res[3]
|
||||||
|
|
||||||
|
offset = 4
|
||||||
|
optional = response.optional
|
||||||
|
if header.message_type == FULL_SERVER_RESPONSE or AUDIO_ONLY_RESPONSE:
|
||||||
|
# read event
|
||||||
|
if header.message_type_specific_flags == MsgTypeFlagWithEvent:
|
||||||
|
optional.event = int.from_bytes(res[offset:offset + 4], "big", signed=True)
|
||||||
|
offset += 4
|
||||||
|
if optional.event == EVENT_NONE:
|
||||||
|
return response
|
||||||
|
# read connectionId
|
||||||
|
elif optional.event == EVENT_ConnectionStarted:
|
||||||
|
optional.connectionId, offset = read_res_content(res, offset)
|
||||||
|
elif optional.event == EVENT_SessionStarted or optional.event == EVENT_SessionFinished:
|
||||||
|
optional.sessionId, offset = read_res_content(res, offset)
|
||||||
|
optional.response_meta_json, offset = read_res_content(res, offset)
|
||||||
|
elif optional.event == EVENT_TTSResponse:
|
||||||
|
optional.sessionId, offset = read_res_content(res, offset)
|
||||||
|
response.payload, offset = read_res_payload(res, offset)
|
||||||
|
elif optional.event == EVENT_TTSSentenceEnd:
|
||||||
|
optional.sessionId, offset = read_res_content(res, offset)
|
||||||
|
response.payload_json, offset = read_res_content(res, offset)
|
||||||
|
|
||||||
|
elif header.message_type == ERROR_INFORMATION:
|
||||||
|
optional.errorCode = int.from_bytes(res[offset:offset + 4], "big", signed=True)
|
||||||
|
offset += 4
|
||||||
|
response.payload, offset = read_res_payload(res, offset)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
async def send_event(ws, header, optional=None, payload=None):
|
||||||
|
full_client_request = bytearray(header)
|
||||||
|
if optional is not None:
|
||||||
|
full_client_request.extend(optional)
|
||||||
|
if payload is not None:
|
||||||
|
payload_size = len(payload).to_bytes(4, 'big', signed=True)
|
||||||
|
full_client_request.extend(payload_size)
|
||||||
|
full_client_request.extend(payload)
|
||||||
|
await ws.send(full_client_request)
|
||||||
|
|
||||||
|
|
||||||
|
# 修改:TTS状态管理类,添加消息ID和任务追踪
|
||||||
|
class TTSState:
|
||||||
|
def __init__(self, message_id: str):
|
||||||
|
self.message_id = message_id
|
||||||
|
self.volc_ws: OptionalType[websockets.WebSocketServerProtocol] = None
|
||||||
|
self.session_id: OptionalType[str] = None
|
||||||
|
self.task: OptionalType[asyncio.Task] = None # 用于追踪异步任务
|
||||||
|
self.is_processing = False
|
||||||
|
|
||||||
|
|
||||||
|
# 全局状态管理
|
||||||
|
class TTSManager:
|
||||||
|
def __init__(self):
|
||||||
|
# WebSocket -> 消息ID -> TTS状态
|
||||||
|
self.connections: Dict[any, Dict[str, TTSState]] = {}
|
||||||
|
# 会话ID -> 消息ID 的映射,用于路由响应
|
||||||
|
self.session_to_message: Dict[str, str] = {}
|
||||||
|
# 消息ID -> WebSocket 的映射
|
||||||
|
self.message_to_websocket: Dict[str, any] = {}
|
||||||
|
|
||||||
|
def get_connection_states(self, websocket) -> Dict[str, TTSState]:
|
||||||
|
"""获取WebSocket连接的所有TTS状态"""
|
||||||
|
if websocket not in self.connections:
|
||||||
|
self.connections[websocket] = {}
|
||||||
|
return self.connections[websocket]
|
||||||
|
|
||||||
|
def add_tts_state(self, websocket, message_id: str) -> TTSState:
|
||||||
|
"""添加新的TTS状态"""
|
||||||
|
states = self.get_connection_states(websocket)
|
||||||
|
if message_id in states:
|
||||||
|
# 如果已存在,先清理旧的
|
||||||
|
self.cleanup_message_state(websocket, message_id)
|
||||||
|
|
||||||
|
tts_state = TTSState(message_id)
|
||||||
|
states[message_id] = tts_state
|
||||||
|
self.message_to_websocket[message_id] = websocket
|
||||||
|
return tts_state
|
||||||
|
|
||||||
|
def get_tts_state(self, websocket, message_id: str) -> OptionalType[TTSState]:
|
||||||
|
"""获取指定的TTS状态"""
|
||||||
|
states = self.get_connection_states(websocket)
|
||||||
|
return states.get(message_id)
|
||||||
|
|
||||||
|
def register_session(self, session_id: str, message_id: str):
|
||||||
|
"""注册会话ID和消息ID的映射"""
|
||||||
|
self.session_to_message[session_id] = message_id
|
||||||
|
|
||||||
|
def get_message_by_session(self, session_id: str) -> OptionalType[str]:
|
||||||
|
"""根据会话ID获取消息ID"""
|
||||||
|
return self.session_to_message.get(session_id)
|
||||||
|
|
||||||
|
def get_websocket_by_message(self, message_id: str):
|
||||||
|
"""根据消息ID获取WebSocket"""
|
||||||
|
return self.message_to_websocket.get(message_id)
|
||||||
|
|
||||||
|
def cleanup_message_state(self, websocket, message_id: str):
|
||||||
|
"""清理指定消息的状态"""
|
||||||
|
states = self.get_connection_states(websocket)
|
||||||
|
if message_id in states:
|
||||||
|
tts_state = states[message_id]
|
||||||
|
# 取消任务
|
||||||
|
if tts_state.task and not tts_state.task.done():
|
||||||
|
tts_state.task.cancel()
|
||||||
|
# 清理映射
|
||||||
|
if tts_state.session_id and tts_state.session_id in self.session_to_message:
|
||||||
|
del self.session_to_message[tts_state.session_id]
|
||||||
|
if message_id in self.message_to_websocket:
|
||||||
|
del self.message_to_websocket[message_id]
|
||||||
|
# 删除状态
|
||||||
|
del states[message_id]
|
||||||
|
|
||||||
|
def cleanup_connection(self, websocket):
|
||||||
|
"""清理整个连接的状态"""
|
||||||
|
if websocket in self.connections:
|
||||||
|
states = self.connections[websocket]
|
||||||
|
for message_id in list(states.keys()):
|
||||||
|
self.cleanup_message_state(websocket, message_id)
|
||||||
|
del self.connections[websocket]
|
||||||
|
|
||||||
|
|
||||||
|
# 全局TTS管理器实例
|
||||||
|
tts_manager = TTSManager()
|
||||||
|
|
||||||
|
|
||||||
|
# 初始化独立的TTS连接
|
||||||
|
async def create_tts_connection() -> websockets.WebSocketServerProtocol:
|
||||||
|
"""创建独立的TTS连接"""
|
||||||
|
log_id = gen_log_id()
|
||||||
|
ws_header = {
|
||||||
|
"X-Api-App-Key": APP_ID,
|
||||||
|
"X-Api-Access-Key": TOKEN,
|
||||||
|
"X-Api-Resource-Id": 'volc.service_type.10029',
|
||||||
|
"X-Api-Connect-Id": str(uuid.uuid4()),
|
||||||
|
"X-Tt-Logid": log_id,
|
||||||
|
}
|
||||||
|
url = 'wss://openspeech.bytedance.com/api/v3/tts/bidirection'
|
||||||
|
volc_ws = await websockets.connect(url, additional_headers=ws_header, max_size=1000000000)
|
||||||
|
|
||||||
|
# 启动连接
|
||||||
|
header = Header(message_type=FULL_CLIENT_REQUEST,
|
||||||
|
message_type_specific_flags=MsgTypeFlagWithEvent).as_bytes()
|
||||||
|
optional = Optional(event=EVENT_Start_Connection).as_bytes()
|
||||||
|
payload = str.encode("{}")
|
||||||
|
await send_event(volc_ws, header, optional, payload)
|
||||||
|
|
||||||
|
# 等待连接确认
|
||||||
|
raw_data = await volc_ws.recv()
|
||||||
|
res = parser_response(raw_data)
|
||||||
|
if res.optional.event != EVENT_ConnectionStarted:
|
||||||
|
raise Exception("TTS连接失败")
|
||||||
|
|
||||||
|
return volc_ws
|
||||||
|
|
||||||
|
|
||||||
|
# 处理单个TTS任务
|
||||||
|
async def process_tts_task(websocket, message_id: str, text: str):
|
||||||
|
"""处理单个TTS任务(独立协程)"""
|
||||||
|
tts_state = None
|
||||||
|
try:
|
||||||
|
print(f"开始处理TTS任务 [{message_id}]: {text}")
|
||||||
|
|
||||||
|
# 获取TTS状态
|
||||||
|
tts_state = tts_manager.get_tts_state(websocket, message_id)
|
||||||
|
if not tts_state:
|
||||||
|
raise Exception(f"找不到TTS状态: {message_id}")
|
||||||
|
|
||||||
|
tts_state.is_processing = True
|
||||||
|
|
||||||
|
# 创建独立的TTS连接
|
||||||
|
tts_state.volc_ws = await create_tts_connection()
|
||||||
|
|
||||||
|
# 创建会话
|
||||||
|
tts_state.session_id = uuid.uuid4().__str__().replace('-', '')
|
||||||
|
tts_manager.register_session(tts_state.session_id, message_id)
|
||||||
|
|
||||||
|
print(f"创建TTS会话 [{message_id}]: {tts_state.session_id}")
|
||||||
|
header = Header(message_type=FULL_CLIENT_REQUEST,
|
||||||
|
message_type_specific_flags=MsgTypeFlagWithEvent,
|
||||||
|
serial_method=JSON).as_bytes()
|
||||||
|
optional = Optional(event=EVENT_StartSession, sessionId=tts_state.session_id).as_bytes()
|
||||||
|
payload = get_payload_bytes(event=EVENT_StartSession, speaker=SPEAKER)
|
||||||
|
await send_event(tts_state.volc_ws, header, optional, payload)
|
||||||
|
|
||||||
|
raw_data = await tts_state.volc_ws.recv()
|
||||||
|
res = parser_response(raw_data)
|
||||||
|
if res.optional.event != EVENT_SessionStarted:
|
||||||
|
raise Exception("TTS会话启动失败")
|
||||||
|
print(f"TTS会话创建成功 [{message_id}]: {tts_state.session_id}")
|
||||||
|
|
||||||
|
# 发送文本到TTS服务
|
||||||
|
print(f"发送文本到TTS服务 [{message_id}]...")
|
||||||
|
header = Header(message_type=FULL_CLIENT_REQUEST,
|
||||||
|
message_type_specific_flags=MsgTypeFlagWithEvent,
|
||||||
|
serial_method=JSON).as_bytes()
|
||||||
|
optional = Optional(event=EVENT_TaskRequest, sessionId=tts_state.session_id).as_bytes()
|
||||||
|
payload = get_payload_bytes(event=EVENT_TaskRequest, text=text, speaker=SPEAKER)
|
||||||
|
await send_event(tts_state.volc_ws, header, optional, payload)
|
||||||
|
|
||||||
|
# 接收TTS响应并发送到前端
|
||||||
|
print(f"开始接收TTS响应 [{message_id}]...")
|
||||||
|
audio_count = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
raw_data = await asyncio.wait_for(
|
||||||
|
tts_state.volc_ws.recv(),
|
||||||
|
timeout=30
|
||||||
|
)
|
||||||
|
res = parser_response(raw_data)
|
||||||
|
|
||||||
|
print(f"收到TTS事件 [{message_id}]: {res.optional.event}")
|
||||||
|
|
||||||
|
if res.optional.event == EVENT_TTSSentenceEnd:
|
||||||
|
print(f"句子结束事件 [{message_id}] - 直接完成")
|
||||||
|
break
|
||||||
|
|
||||||
|
elif res.optional.event == EVENT_SessionFinished:
|
||||||
|
print(f"收到会话结束事件 [{message_id}]")
|
||||||
|
break
|
||||||
|
|
||||||
|
elif res.optional.event == EVENT_TTSResponse:
|
||||||
|
audio_count += 1
|
||||||
|
print(f"发送音频数据 [{message_id}] #{audio_count},大小: {len(res.payload)}")
|
||||||
|
# 发送音频数据,包含消息ID
|
||||||
|
await websocket.send_json({
|
||||||
|
"type": "tts_audio_data",
|
||||||
|
"messageId": message_id,
|
||||||
|
"audioData": res.payload.hex() # 转为hex字符串
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
print(f"未知TTS事件 [{message_id}]: {res.optional.event}")
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
print(f"TTS响应超时 [{message_id}],强制结束")
|
||||||
|
|
||||||
|
# 发送完成消息
|
||||||
|
await websocket.send_json({
|
||||||
|
"type": "tts_audio_complete",
|
||||||
|
"messageId": message_id
|
||||||
|
})
|
||||||
|
print(f"TTS处理完成 [{message_id}],共发送 {audio_count} 个音频包")
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
print(f"TTS任务被取消 [{message_id}]")
|
||||||
|
await websocket.send_json({
|
||||||
|
"type": "tts_error",
|
||||||
|
"messageId": message_id,
|
||||||
|
"message": "TTS任务被取消"
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
print(f"TTS处理异常 [{message_id}]: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
await websocket.send_json({
|
||||||
|
"type": "tts_error",
|
||||||
|
"messageId": message_id,
|
||||||
|
"message": f"TTS处理失败: {str(e)}"
|
||||||
|
})
|
||||||
|
finally:
|
||||||
|
# 清理资源
|
||||||
|
if tts_state:
|
||||||
|
tts_state.is_processing = False
|
||||||
|
if tts_state.volc_ws:
|
||||||
|
try:
|
||||||
|
await tts_state.volc_ws.close()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
# 清理状态
|
||||||
|
tts_manager.cleanup_message_state(websocket, message_id)
|
||||||
|
|
||||||
|
|
||||||
|
# 启动TTS文本转换
|
||||||
|
async def handle_tts_text(websocket, message_id: str, text: str):
|
||||||
|
"""启动TTS文本转换"""
|
||||||
|
# 创建新的TTS状态
|
||||||
|
tts_state = tts_manager.add_tts_state(websocket, message_id)
|
||||||
|
|
||||||
|
# 启动异步任务
|
||||||
|
tts_state.task = asyncio.create_task(
|
||||||
|
process_tts_task(websocket, message_id, text)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# 取消TTS任务
|
||||||
|
async def handle_tts_cancel(websocket, message_id: str):
|
||||||
|
"""取消TTS任务"""
|
||||||
|
tts_state = tts_manager.get_tts_state(websocket, message_id)
|
||||||
|
if tts_state and tts_state.task and not tts_state.task.done():
|
||||||
|
tts_state.task.cancel()
|
||||||
|
await websocket.send_json({
|
||||||
|
"type": "tts_complete",
|
||||||
|
"messageId": message_id
|
||||||
|
})
|
||||||
|
tts_manager.cleanup_message_state(websocket, message_id)
|
||||||
|
|
||||||
|
|
||||||
|
# 清理连接的所有TTS资源
|
||||||
|
async def cleanup_connection_tts(websocket):
|
||||||
|
"""清理连接的所有TTS资源"""
|
||||||
|
print(f"清理连接的TTS资源...")
|
||||||
|
tts_manager.cleanup_connection(websocket)
|
||||||
|
print("TTS资源清理完成")
|
||||||
|
|||||||
@@ -1,14 +1,19 @@
|
|||||||
|
# websocket_service.py
|
||||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||||
from typing import Set
|
from typing import Set
|
||||||
from aip import AipSpeech
|
from aip import AipSpeech
|
||||||
from app.constants.asr import APP_ID, API_KEY, SECRET_KEY
|
from app.constants.asr import APP_ID, API_KEY, SECRET_KEY
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
# 导入修改后的TTS模块
|
||||||
|
from . import tts
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
active_connections: Set[WebSocket] = set()
|
active_connections: Set[WebSocket] = set()
|
||||||
|
|
||||||
asr_client = AipSpeech(APP_ID, API_KEY, SECRET_KEY)
|
asr_client = AipSpeech(APP_ID, API_KEY, SECRET_KEY)
|
||||||
|
|
||||||
|
|
||||||
async def asr_buffer(buffer_data: bytes) -> str:
|
async def asr_buffer(buffer_data: bytes) -> str:
|
||||||
result = asr_client.asr(buffer_data, 'pcm', 16000, {'dev_pid': 1537})
|
result = asr_client.asr(buffer_data, 'pcm', 16000, {'dev_pid': 1537})
|
||||||
if result.get('err_msg') == 'success.':
|
if result.get('err_msg') == 'success.':
|
||||||
@@ -16,6 +21,7 @@ async def asr_buffer(buffer_data: bytes) -> str:
|
|||||||
else:
|
else:
|
||||||
return '语音转换失败'
|
return '语音转换失败'
|
||||||
|
|
||||||
|
|
||||||
async def broadcast_online_count():
|
async def broadcast_online_count():
|
||||||
data = {"online_count": len(active_connections), 'type': 'count'}
|
data = {"online_count": len(active_connections), 'type': 'count'}
|
||||||
to_remove = set()
|
to_remove = set()
|
||||||
@@ -27,12 +33,14 @@ async def broadcast_online_count():
|
|||||||
for ws in to_remove:
|
for ws in to_remove:
|
||||||
active_connections.remove(ws)
|
active_connections.remove(ws)
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/websocket")
|
@router.websocket("/websocket")
|
||||||
async def websocket_online_count(websocket: WebSocket):
|
async def websocket_online_count(websocket: WebSocket):
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
active_connections.add(websocket)
|
active_connections.add(websocket)
|
||||||
await broadcast_online_count()
|
await broadcast_online_count()
|
||||||
temp_buffer = bytes()
|
temp_buffer = bytes()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
message = await websocket.receive()
|
message = await websocket.receive()
|
||||||
@@ -45,15 +53,59 @@ async def websocket_online_count(websocket: WebSocket):
|
|||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
msg_type = data.get("type")
|
msg_type = data.get("type")
|
||||||
|
|
||||||
if msg_type == "ping":
|
if msg_type == "ping":
|
||||||
await websocket.send_json({"online_count": len(active_connections), "type": "count"})
|
await websocket.send_json({"online_count": len(active_connections), "type": "count"})
|
||||||
|
|
||||||
elif msg_type == "asr_end":
|
elif msg_type == "asr_end":
|
||||||
asr_text = await asr_buffer(temp_buffer)
|
asr_text = await asr_buffer(temp_buffer)
|
||||||
await websocket.send_json({"type": "asr_result", "result": asr_text})
|
await websocket.send_json({"type": "asr_result", "result": asr_text})
|
||||||
temp_buffer = bytes()
|
temp_buffer = bytes()
|
||||||
|
|
||||||
|
# 修改:TTS处理支持消息ID
|
||||||
|
elif msg_type == "tts_text":
|
||||||
|
message_id = data.get("messageId")
|
||||||
|
text = data.get("text", "")
|
||||||
|
|
||||||
|
if not message_id:
|
||||||
|
await websocket.send_json({
|
||||||
|
"type": "tts_error",
|
||||||
|
"message": "缺少messageId参数"
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"收到TTS文本请求 [{message_id}]: {text}")
|
||||||
|
try:
|
||||||
|
await tts.handle_tts_text(websocket, message_id, text)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"TTS文本处理异常 [{message_id}]: {e}")
|
||||||
|
await websocket.send_json({
|
||||||
|
"type": "tts_error",
|
||||||
|
"messageId": message_id,
|
||||||
|
"message": f"TTS处理失败: {str(e)}"
|
||||||
|
})
|
||||||
|
|
||||||
|
elif msg_type == "tts_cancel":
|
||||||
|
message_id = data.get("messageId")
|
||||||
|
if message_id:
|
||||||
|
print(f"收到TTS取消请求 [{message_id}]")
|
||||||
|
try:
|
||||||
|
await tts.handle_tts_cancel(websocket, message_id)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"TTS取消处理异常 [{message_id}]: {e}")
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
active_connections.remove(websocket)
|
pass
|
||||||
await broadcast_online_count()
|
except Exception as e:
|
||||||
except Exception:
|
print(f"WebSocket异常: {e}")
|
||||||
active_connections.remove(websocket)
|
finally:
|
||||||
|
# 清理资源
|
||||||
|
active_connections.discard(websocket)
|
||||||
|
|
||||||
|
# 清理所有TTS资源
|
||||||
|
try:
|
||||||
|
await tts.cleanup_connection_tts(websocket)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
await broadcast_online_count()
|
await broadcast_online_count()
|
||||||
|
|||||||
@@ -1,3 +1,7 @@
|
|||||||
appId = '1142362958'
|
# APP_ID = '1142362958'
|
||||||
token = 'O-a4JkyLFrYkME9no11DxFkOY-UnAoFF'
|
# TOKEN = 'O-a4JkyLFrYkME9no11DxFkOY-UnAoFF'
|
||||||
speaker = 'zh_male_beijingxiaoye_moon_bigtts'
|
# SPEAKER = 'zh_male_beijingxiaoye_moon_bigtts'
|
||||||
|
|
||||||
|
APP_ID = '2138450044'
|
||||||
|
TOKEN = 'V04_QumeQZhJrQ_In1Z0VBQm7n0ttMNO'
|
||||||
|
SPEAKER = 'zh_male_beijingxiaoye_moon_bigtts'
|
||||||
2
web/components.d.ts
vendored
2
web/components.d.ts
vendored
@@ -18,10 +18,12 @@ declare module 'vue' {
|
|||||||
NInput: typeof import('naive-ui')['NInput']
|
NInput: typeof import('naive-ui')['NInput']
|
||||||
NMessageProvider: typeof import('naive-ui')['NMessageProvider']
|
NMessageProvider: typeof import('naive-ui')['NMessageProvider']
|
||||||
NPopconfirm: typeof import('naive-ui')['NPopconfirm']
|
NPopconfirm: typeof import('naive-ui')['NPopconfirm']
|
||||||
|
NPopover: typeof import('naive-ui')['NPopover']
|
||||||
NScrollbar: typeof import('naive-ui')['NScrollbar']
|
NScrollbar: typeof import('naive-ui')['NScrollbar']
|
||||||
NSelect: typeof import('naive-ui')['NSelect']
|
NSelect: typeof import('naive-ui')['NSelect']
|
||||||
NTag: typeof import('naive-ui')['NTag']
|
NTag: typeof import('naive-ui')['NTag']
|
||||||
RouterLink: typeof import('vue-router')['RouterLink']
|
RouterLink: typeof import('vue-router')['RouterLink']
|
||||||
RouterView: typeof import('vue-router')['RouterView']
|
RouterView: typeof import('vue-router')['RouterView']
|
||||||
|
Tts: typeof import('./src/components/tts.vue')['default']
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ export default antfu({
|
|||||||
"antfu/top-level-function": "off",
|
"antfu/top-level-function": "off",
|
||||||
"ts/no-unsafe-function-type": "off",
|
"ts/no-unsafe-function-type": "off",
|
||||||
"no-console": "off",
|
"no-console": "off",
|
||||||
"unused-imports/no-unused-vars": "warn"
|
"unused-imports/no-unused-vars": "warn",
|
||||||
|
"ts/no-use-before-define": "off"
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -2,4 +2,5 @@ export { default as ChevronLeftIcon } from "./svg/heroicons/ChevronLeftIcon.svg?
|
|||||||
export { default as ExclamationTriangleIcon } from "./svg/heroicons/ExclamationTriangleIcon.svg?component";
|
export { default as ExclamationTriangleIcon } from "./svg/heroicons/ExclamationTriangleIcon.svg?component";
|
||||||
export { default as microphone } from "./svg/heroicons/MicrophoneIcon.svg?component";
|
export { default as microphone } from "./svg/heroicons/MicrophoneIcon.svg?component";
|
||||||
export { default as PaperAirplaneIcon } from "./svg/heroicons/PaperAirplaneIcon.svg?component";
|
export { default as PaperAirplaneIcon } from "./svg/heroicons/PaperAirplaneIcon.svg?component";
|
||||||
|
export { default as SpeakerWaveIcon } from "./svg/heroicons/SpeakerWaveIcon.svg?component";
|
||||||
export { default as TrashIcon } from "./svg/heroicons/TrashIcon.svg?component";
|
export { default as TrashIcon } from "./svg/heroicons/TrashIcon.svg?component";
|
||||||
|
|||||||
3
web/src/assets/Icons/svg/heroicons/SpeakerWaveIcon.svg
Normal file
3
web/src/assets/Icons/svg/heroicons/SpeakerWaveIcon.svg
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" class="size-6">
|
||||||
|
<path stroke-linecap="round" stroke-linejoin="round" d="M19.114 5.636a9 9 0 0 1 0 12.728M16.463 8.288a5.25 5.25 0 0 1 0 7.424M6.75 8.25l4.72-4.72a.75.75 0 0 1 1.28.53v15.88a.75.75 0 0 1-1.28.53l-4.72-4.72H4.51c-.88 0-1.704-.507-1.938-1.354A9.009 9.009 0 0 1 2.25 12c0-.83.112-1.633.322-2.396C2.806 8.756 3.63 8.25 4.51 8.25H6.75Z" />
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 472 B |
81
web/src/components/tts.vue
Normal file
81
web/src/components/tts.vue
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
<script setup lang="ts">
|
||||||
|
import { SpeakerWaveIcon } from "@/assets/Icons";
|
||||||
|
import { useLayoutStore, useTtsStore } from "@/stores";
|
||||||
|
|
||||||
|
const { text, messageId } = defineProps<{
|
||||||
|
text: string;
|
||||||
|
messageId: string;
|
||||||
|
}>();
|
||||||
|
|
||||||
|
const ttsStore = useTtsStore();
|
||||||
|
const layoutStore = useLayoutStore();
|
||||||
|
const { simpleMode } = storeToRefs(layoutStore);
|
||||||
|
|
||||||
|
// 获取当前消息的状态
|
||||||
|
const isPlaying = computed(() => ttsStore.isPlaying(messageId));
|
||||||
|
const isLoading = computed(() => ttsStore.isLoading(messageId));
|
||||||
|
const hasAudio = computed(() => ttsStore.hasAudio(messageId));
|
||||||
|
|
||||||
|
// 处理按钮点击
|
||||||
|
const handleClick = () => {
|
||||||
|
if (isLoading.value) {
|
||||||
|
return; // 合成中不响应点击
|
||||||
|
}
|
||||||
|
|
||||||
|
if (hasAudio.value) {
|
||||||
|
// 如果音频已准备好,切换播放/暂停
|
||||||
|
if (isPlaying.value) {
|
||||||
|
ttsStore.pause(messageId);
|
||||||
|
} else {
|
||||||
|
ttsStore.play(messageId);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 如果没有音频,开始TTS转换
|
||||||
|
ttsStore.convertText(text, messageId);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
// 当文本改变时清理之前的音频
|
||||||
|
watch(
|
||||||
|
() => text,
|
||||||
|
() => {
|
||||||
|
ttsStore.clearAudio(messageId);
|
||||||
|
}
|
||||||
|
);
|
||||||
|
onUnmounted(() => {
|
||||||
|
ttsStore.clearAudio(messageId);
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<template>
|
||||||
|
<NPopover trigger="hover">
|
||||||
|
<template #trigger>
|
||||||
|
<NButton
|
||||||
|
:loading="isLoading"
|
||||||
|
@click="handleClick"
|
||||||
|
quaternary
|
||||||
|
circle
|
||||||
|
:disabled="!text.trim()"
|
||||||
|
>
|
||||||
|
<SpeakerWaveIcon
|
||||||
|
v-if="!isLoading"
|
||||||
|
class="!w-4 !h-4"
|
||||||
|
:class="{
|
||||||
|
'': !simpleMode,
|
||||||
|
'animate-pulse': isPlaying
|
||||||
|
}"
|
||||||
|
/>
|
||||||
|
</NButton>
|
||||||
|
</template>
|
||||||
|
<span>
|
||||||
|
{{
|
||||||
|
isLoading
|
||||||
|
? "合成中..."
|
||||||
|
: isPlaying
|
||||||
|
? "点击暂停"
|
||||||
|
: hasAudio
|
||||||
|
? "点击播放"
|
||||||
|
: "语音合成"
|
||||||
|
}}
|
||||||
|
</span>
|
||||||
|
</NPopover>
|
||||||
|
</template>
|
||||||
@@ -11,6 +11,7 @@ export interface Message {
|
|||||||
thinking?: string;
|
thinking?: string;
|
||||||
role?: string;
|
role?: string;
|
||||||
usage?: UsageInfo;
|
usage?: UsageInfo;
|
||||||
|
id?: string;
|
||||||
[property: string]: any;
|
[property: string]: any;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ const router = createRouter({
|
|||||||
name: "community",
|
name: "community",
|
||||||
component: community,
|
component: community,
|
||||||
meta: {
|
meta: {
|
||||||
title: "社区"
|
title: "对话"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
30
web/src/services/audio_websocket.ts
Normal file
30
web/src/services/audio_websocket.ts
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
import { useWebSocketStore } from "@/services";
|
||||||
|
|
||||||
|
export const useAudioWebSocket = () => {
|
||||||
|
const webSocketStore = useWebSocketStore();
|
||||||
|
|
||||||
|
const sendMessage = (data: string | Uint8Array) => {
|
||||||
|
if (webSocketStore.connected) {
|
||||||
|
if (typeof data === "string") {
|
||||||
|
webSocketStore.send(data);
|
||||||
|
} else {
|
||||||
|
webSocketStore.websocket?.send(data);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const ensureConnection = async (): Promise<void> => {
|
||||||
|
if (!webSocketStore.connected) {
|
||||||
|
webSocketStore.connect();
|
||||||
|
await new Promise<void>((resolve) => {
|
||||||
|
const check = () => {
|
||||||
|
if (webSocketStore.connected) resolve();
|
||||||
|
else setTimeout(check, 100);
|
||||||
|
};
|
||||||
|
check();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return { sendMessage, ensureConnection };
|
||||||
|
};
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
export * from "./audio_websocket";
|
||||||
export * from "./base_service";
|
export * from "./base_service";
|
||||||
export * from "./chat_service";
|
export * from "./chat_service";
|
||||||
export * from "./websocket";
|
export * from "./websocket";
|
||||||
|
|||||||
@@ -1,28 +1,143 @@
|
|||||||
import { useChatStore } from "@/stores";
|
import { useChatStore, useTtsStore } from "@/stores";
|
||||||
|
|
||||||
// WebSocket
|
// WebSocket
|
||||||
export const useWebSocketStore = defineStore("websocket", () => {
|
export const useWebSocketStore = defineStore("websocket", () => {
|
||||||
const websocket = ref<WebSocket>();
|
const websocket = ref<WebSocket>();
|
||||||
const connected = ref(false);
|
const connected = ref(false);
|
||||||
const chatStore = useChatStore();
|
const chatStore = useChatStore();
|
||||||
|
const ttsStore = useTtsStore();
|
||||||
|
|
||||||
const { onlineCount } = storeToRefs(chatStore);
|
const { onlineCount } = storeToRefs(chatStore);
|
||||||
|
|
||||||
const onmessage = (e: MessageEvent) => {
|
const onmessage = (e: MessageEvent) => {
|
||||||
const data = JSON.parse(e.data);
|
// 检查消息类型
|
||||||
switch (data.type) {
|
if (e.data instanceof ArrayBuffer) {
|
||||||
case "count":
|
// 处理二进制音频数据(兜底处理,新版本应该不会用到)
|
||||||
onlineCount.value = data.online_count;
|
console.log("收到二进制音频数据,大小:", e.data.byteLength);
|
||||||
break;
|
console.warn("收到旧格式的二进制数据,无法确定messageId");
|
||||||
case "asr_result":
|
// 可以选择忽略或者作为兜底处理
|
||||||
chatStore.addMessageToHistory(data.result);
|
} else if (e.data instanceof Blob) {
|
||||||
|
// 如果是Blob,转换为ArrayBuffer(兜底处理)
|
||||||
|
e.data.arrayBuffer().then((buffer: ArrayBuffer) => {
|
||||||
|
console.log("收到Blob音频数据,大小:", buffer.byteLength);
|
||||||
|
console.warn("收到旧格式的Blob数据,无法确定messageId");
|
||||||
|
});
|
||||||
|
} else if (typeof e.data === "string") {
|
||||||
|
// 处理文本JSON消息
|
||||||
|
try {
|
||||||
|
const data = JSON.parse(e.data);
|
||||||
|
switch (data.type) {
|
||||||
|
case "count":
|
||||||
|
onlineCount.value = data.online_count;
|
||||||
|
break;
|
||||||
|
case "asr_result":
|
||||||
|
chatStore.addMessageToHistory(data.result);
|
||||||
|
break;
|
||||||
|
|
||||||
|
// 新的TTS消息格式处理
|
||||||
|
case "tts_audio_data":
|
||||||
|
// 新的音频数据格式,包含messageId和hex格式的音频数据
|
||||||
|
if (data.messageId && data.audioData) {
|
||||||
|
console.log(
|
||||||
|
`收到TTS音频数据 [${data.messageId}],hex长度:`,
|
||||||
|
data.audioData.length
|
||||||
|
);
|
||||||
|
try {
|
||||||
|
// 将hex字符串转换为ArrayBuffer
|
||||||
|
const bytes = data.audioData
|
||||||
|
.match(/.{1,2}/g)
|
||||||
|
?.map((byte: string) => Number.parseInt(byte, 16));
|
||||||
|
if (bytes) {
|
||||||
|
const buffer = new Uint8Array(bytes).buffer;
|
||||||
|
console.log(
|
||||||
|
`转换后的音频数据大小 [${data.messageId}]:`,
|
||||||
|
buffer.byteLength
|
||||||
|
);
|
||||||
|
ttsStore.handleAudioData(buffer, data.messageId);
|
||||||
|
} else {
|
||||||
|
console.error(`音频数据格式错误 [${data.messageId}]`);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`音频数据转换失败 [${data.messageId}]:`, error);
|
||||||
|
ttsStore.handleError(
|
||||||
|
`音频数据转换失败: ${error}`,
|
||||||
|
data.messageId
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
console.error("tts_audio_data消息格式错误:", data);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "tts_audio_complete":
|
||||||
|
// TTS音频传输完成
|
||||||
|
if (data.messageId) {
|
||||||
|
console.log(`TTS音频传输完成 [${data.messageId}]`);
|
||||||
|
ttsStore.finishConversion(data.messageId);
|
||||||
|
} else {
|
||||||
|
console.log("TTS音频传输完成(无messageId)");
|
||||||
|
// 兜底处理,可能是旧格式
|
||||||
|
ttsStore.finishConversion(data.messageId);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "tts_complete":
|
||||||
|
// TTS会话结束
|
||||||
|
if (data.messageId) {
|
||||||
|
console.log(`TTS会话结束 [${data.messageId}]`);
|
||||||
|
// 可以添加额外的清理逻辑
|
||||||
|
} else {
|
||||||
|
console.log("TTS会话结束");
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "tts_error":
|
||||||
|
// TTS错误
|
||||||
|
if (data.messageId) {
|
||||||
|
console.error(`TTS错误 [${data.messageId}]:`, data.message);
|
||||||
|
ttsStore.handleError(data.message, data.messageId);
|
||||||
|
} else {
|
||||||
|
console.error("TTS错误:", data.message);
|
||||||
|
// 兜底处理,可能是旧格式
|
||||||
|
ttsStore.handleError(data.message, data.messageId || "unknown");
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
|
||||||
|
// 保留旧的消息类型作为兜底处理
|
||||||
|
case "tts_audio_complete_legacy":
|
||||||
|
case "tts_complete_legacy":
|
||||||
|
case "tts_error_legacy":
|
||||||
|
console.log("收到旧格式TTS消息:", data.type);
|
||||||
|
// 可以选择处理或忽略
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
console.log("未知消息类型:", data.type, data);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error("JSON解析错误:", error, "原始数据:", e.data);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
console.warn("收到未知格式的消息:", typeof e.data, e.data);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const send = (data: string) => {
|
const send = (data: string) => {
|
||||||
if (websocket.value && websocket.value.readyState === WebSocket.OPEN)
|
if (websocket.value && websocket.value.readyState === WebSocket.OPEN) {
|
||||||
websocket.value?.send(data);
|
websocket.value?.send(data);
|
||||||
|
} else {
|
||||||
|
console.warn("WebSocket未连接,无法发送消息:", data);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const sendBinary = (data: ArrayBuffer | Uint8Array) => {
|
||||||
|
if (websocket.value && websocket.value.readyState === WebSocket.OPEN) {
|
||||||
|
websocket.value?.send(data);
|
||||||
|
} else {
|
||||||
|
console.warn("WebSocket未连接,无法发送二进制数据");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
const close = () => {
|
const close = () => {
|
||||||
websocket.value?.close();
|
websocket.value?.close();
|
||||||
};
|
};
|
||||||
@@ -33,11 +148,15 @@ export const useWebSocketStore = defineStore("websocket", () => {
|
|||||||
|
|
||||||
websocket.value.onopen = () => {
|
websocket.value.onopen = () => {
|
||||||
connected.value = true;
|
connected.value = true;
|
||||||
|
console.log("WebSocket连接成功");
|
||||||
|
|
||||||
let pingIntervalId: NodeJS.Timeout | undefined;
|
let pingIntervalId: NodeJS.Timeout | undefined;
|
||||||
|
|
||||||
if (pingIntervalId) clearInterval(pingIntervalId);
|
if (pingIntervalId) clearInterval(pingIntervalId);
|
||||||
pingIntervalId = setInterval(() => send("ping"), 30 * 1000);
|
pingIntervalId = setInterval(() => {
|
||||||
|
// 修改ping格式为JSON格式,与后端保持一致
|
||||||
|
send(JSON.stringify({ type: "ping" }));
|
||||||
|
}, 30 * 1000);
|
||||||
|
|
||||||
if (websocket.value) {
|
if (websocket.value) {
|
||||||
websocket.value.onmessage = onmessage;
|
websocket.value.onmessage = onmessage;
|
||||||
@@ -45,20 +164,28 @@ export const useWebSocketStore = defineStore("websocket", () => {
|
|||||||
websocket.value.onerror = (e: Event) => {
|
websocket.value.onerror = (e: Event) => {
|
||||||
console.error(`WebSocket错误:${(e as ErrorEvent).message}`);
|
console.error(`WebSocket错误:${(e as ErrorEvent).message}`);
|
||||||
};
|
};
|
||||||
websocket.value.onclose = () => {
|
|
||||||
|
websocket.value.onclose = (e: CloseEvent) => {
|
||||||
connected.value = false;
|
connected.value = false;
|
||||||
|
console.log(`WebSocket连接关闭: ${e.code} ${e.reason}`);
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
|
console.log("尝试重新连接WebSocket...");
|
||||||
connect(); // 尝试重新连接
|
connect(); // 尝试重新连接
|
||||||
}, 1000); // 1秒后重试连接
|
}, 1000); // 1秒后重试连接
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
websocket.value.onerror = (e: Event) => {
|
||||||
|
console.error("WebSocket连接错误:", e);
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
return {
|
return {
|
||||||
websocket,
|
websocket,
|
||||||
connected,
|
connected,
|
||||||
send,
|
send,
|
||||||
|
sendBinary,
|
||||||
close,
|
close,
|
||||||
connect
|
connect
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -118,7 +118,10 @@ export const useChatStore = defineStore("chat", () => {
|
|||||||
historyMessages.value[historyMessages.value.length - 1].thinking =
|
historyMessages.value[historyMessages.value.length - 1].thinking =
|
||||||
thinkingContent;
|
thinkingContent;
|
||||||
}
|
}
|
||||||
);
|
).then(() => {
|
||||||
|
historyMessages.value[historyMessages.value.length - 1].id =
|
||||||
|
new Date().getTime().toString();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
export * from "./asr_store";
|
export * from "./asr_store";
|
||||||
export * from "./chat_store";
|
export * from "./chat_store";
|
||||||
export * from "./layout_store";
|
export * from "./layout_store";
|
||||||
|
export * from "./tts_store";
|
||||||
302
web/src/stores/tts_store.ts
Normal file
302
web/src/stores/tts_store.ts
Normal file
@@ -0,0 +1,302 @@
|
|||||||
|
import { useAudioWebSocket } from "@/services";
|
||||||
|
import { createAudioUrl, mergeAudioChunks } from "@/utils";
|
||||||
|
|
||||||
|
interface AudioState {
|
||||||
|
isPlaying: boolean;
|
||||||
|
isLoading: boolean;
|
||||||
|
audioElement: HTMLAudioElement | null;
|
||||||
|
audioUrl: string | null;
|
||||||
|
audioChunks: ArrayBuffer[];
|
||||||
|
hasError: boolean;
|
||||||
|
errorMessage: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useTtsStore = defineStore("tts", () => {
|
||||||
|
// 多音频状态管理 - 以消息ID为key
|
||||||
|
const audioStates = ref<Map<string, AudioState>>(new Map());
|
||||||
|
|
||||||
|
// 当前活跃的转换请求(保留用于兼容性)
|
||||||
|
const activeConversion = ref<string | null>(null);
|
||||||
|
|
||||||
|
// 会话状态
|
||||||
|
const hasActiveSession = ref(false);
|
||||||
|
|
||||||
|
// WebSocket连接
|
||||||
|
const { sendMessage, ensureConnection } = useAudioWebSocket();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取或创建音频状态
|
||||||
|
*/
|
||||||
|
const getAudioState = (messageId: string): AudioState => {
|
||||||
|
if (!audioStates.value.has(messageId)) {
|
||||||
|
audioStates.value.set(messageId, {
|
||||||
|
isPlaying: false,
|
||||||
|
isLoading: false,
|
||||||
|
audioElement: null,
|
||||||
|
audioUrl: null,
|
||||||
|
audioChunks: [],
|
||||||
|
hasError: false,
|
||||||
|
errorMessage: ""
|
||||||
|
});
|
||||||
|
}
|
||||||
|
return audioStates.value.get(messageId)!;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 发送文本进行TTS转换
|
||||||
|
*/
|
||||||
|
const convertText = async (text: string, messageId: string) => {
|
||||||
|
try {
|
||||||
|
await ensureConnection();
|
||||||
|
|
||||||
|
// 暂停其他正在播放的音频
|
||||||
|
pauseAll();
|
||||||
|
|
||||||
|
// 获取当前消息的状态
|
||||||
|
const state = getAudioState(messageId);
|
||||||
|
|
||||||
|
// 清理之前的音频和错误状态
|
||||||
|
clearAudioState(state);
|
||||||
|
state.isLoading = true;
|
||||||
|
state.audioChunks = [];
|
||||||
|
|
||||||
|
// 设置当前活跃转换
|
||||||
|
activeConversion.value = messageId;
|
||||||
|
hasActiveSession.value = true;
|
||||||
|
|
||||||
|
// 发送文本到TTS服务
|
||||||
|
sendMessage(JSON.stringify({ type: "tts_text", text, messageId }));
|
||||||
|
} catch (error) {
|
||||||
|
handleError(`连接失败: ${error}`, messageId);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 处理接收到的音频数据 - 修改为支持messageId参数
|
||||||
|
*/
|
||||||
|
const handleAudioData = (data: ArrayBuffer, messageId?: string) => {
|
||||||
|
// 如果传递了messageId就使用它,否则使用activeConversion
|
||||||
|
const targetMessageId = messageId || activeConversion.value;
|
||||||
|
if (!targetMessageId) {
|
||||||
|
console.warn("handleAudioData: 没有有效的messageId");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log(`接收音频数据 [${targetMessageId}],大小:`, data.byteLength);
|
||||||
|
const state = getAudioState(targetMessageId);
|
||||||
|
state.audioChunks.push(data);
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 完成TTS转换,创建播放器并自动播放 - 修改为支持messageId参数
|
||||||
|
*/
|
||||||
|
const finishConversion = async (messageId?: string) => {
|
||||||
|
// 如果传递了messageId就使用它,否则使用activeConversion
|
||||||
|
const targetMessageId = messageId || activeConversion.value;
|
||||||
|
if (!targetMessageId) {
|
||||||
|
console.warn("finishConversion: 没有有效的messageId");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const state = getAudioState(targetMessageId);
|
||||||
|
console.log(
|
||||||
|
`完成TTS转换 [${targetMessageId}],音频片段数量:`,
|
||||||
|
state.audioChunks.length
|
||||||
|
);
|
||||||
|
|
||||||
|
if (state.audioChunks.length === 0) {
|
||||||
|
handleError("没有接收到音频数据", targetMessageId);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 合并音频片段
|
||||||
|
const mergedAudio = mergeAudioChunks(state.audioChunks);
|
||||||
|
console.log(
|
||||||
|
`合并后音频大小 [${targetMessageId}]:`,
|
||||||
|
mergedAudio.byteLength
|
||||||
|
);
|
||||||
|
|
||||||
|
// 创建音频URL和元素
|
||||||
|
state.audioUrl = createAudioUrl(mergedAudio);
|
||||||
|
state.audioElement = new Audio(state.audioUrl);
|
||||||
|
|
||||||
|
// 设置音频事件
|
||||||
|
setupAudioEvents(state, targetMessageId);
|
||||||
|
|
||||||
|
state.isLoading = false;
|
||||||
|
|
||||||
|
// 清除activeConversion(如果是当前活跃的)
|
||||||
|
if (activeConversion.value === targetMessageId) {
|
||||||
|
activeConversion.value = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log(`TTS音频准备完成 [${targetMessageId}],开始自动播放`);
|
||||||
|
|
||||||
|
// 自动播放
|
||||||
|
await play(targetMessageId);
|
||||||
|
} catch (error) {
|
||||||
|
handleError(`音频处理失败: ${error}`, targetMessageId);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置音频事件监听
|
||||||
|
*/
|
||||||
|
const setupAudioEvents = (state: AudioState, messageId: string) => {
|
||||||
|
if (!state.audioElement) return;
|
||||||
|
|
||||||
|
const audio = state.audioElement;
|
||||||
|
|
||||||
|
audio.addEventListener("ended", () => {
|
||||||
|
state.isPlaying = false;
|
||||||
|
console.log(`音频播放结束 [${messageId}]`);
|
||||||
|
});
|
||||||
|
|
||||||
|
audio.addEventListener("error", (e) => {
|
||||||
|
console.error(`音频播放错误 [${messageId}]:`, e);
|
||||||
|
handleError("音频播放失败", messageId);
|
||||||
|
});
|
||||||
|
|
||||||
|
audio.addEventListener("canplaythrough", () => {
|
||||||
|
console.log(`音频可以播放 [${messageId}]`);
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 播放指定消息的音频
|
||||||
|
*/
|
||||||
|
const play = async (messageId: string) => {
|
||||||
|
const state = getAudioState(messageId);
|
||||||
|
|
||||||
|
if (!state.audioElement) {
|
||||||
|
handleError("音频未准备好", messageId);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 暂停其他正在播放的音频
|
||||||
|
pauseAll(messageId);
|
||||||
|
|
||||||
|
await state.audioElement.play();
|
||||||
|
state.isPlaying = true;
|
||||||
|
state.hasError = false;
|
||||||
|
state.errorMessage = "";
|
||||||
|
console.log(`开始播放音频 [${messageId}]`);
|
||||||
|
} catch (error) {
|
||||||
|
handleError(`播放失败: ${error}`, messageId);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 暂停指定消息的音频
|
||||||
|
*/
|
||||||
|
const pause = (messageId: string) => {
|
||||||
|
const state = getAudioState(messageId);
|
||||||
|
|
||||||
|
if (!state.audioElement) return;
|
||||||
|
|
||||||
|
state.audioElement.pause();
|
||||||
|
state.isPlaying = false;
|
||||||
|
console.log(`暂停音频 [${messageId}]`);
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 暂停所有音频
|
||||||
|
*/
|
||||||
|
const pauseAll = (excludeMessageId?: string) => {
|
||||||
|
audioStates.value.forEach((state, messageId) => {
|
||||||
|
if (excludeMessageId && messageId === excludeMessageId) return;
|
||||||
|
if (state.isPlaying && state.audioElement) {
|
||||||
|
state.audioElement.pause();
|
||||||
|
state.isPlaying = false;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 处理TTS错误 - 修改为支持messageId参数
|
||||||
|
*/
|
||||||
|
const handleError = (errorMsg: string, messageId?: string) => {
|
||||||
|
// 如果传递了messageId就使用它,否则使用activeConversion
|
||||||
|
const targetMessageId = messageId || activeConversion.value;
|
||||||
|
if (!targetMessageId) {
|
||||||
|
console.error(`TTS错误 (无messageId): ${errorMsg}`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
console.error(`TTS错误 [${targetMessageId}]: ${errorMsg}`);
|
||||||
|
const state = getAudioState(targetMessageId);
|
||||||
|
state.hasError = true;
|
||||||
|
state.errorMessage = errorMsg;
|
||||||
|
state.isLoading = false;
|
||||||
|
|
||||||
|
if (activeConversion.value === targetMessageId) {
|
||||||
|
activeConversion.value = null;
|
||||||
|
hasActiveSession.value = false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 清理指定消息的音频资源
|
||||||
|
*/
|
||||||
|
const clearAudio = (messageId: string) => {
|
||||||
|
const state = getAudioState(messageId);
|
||||||
|
clearAudioState(state);
|
||||||
|
audioStates.value.delete(messageId);
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 清理音频状态
|
||||||
|
*/
|
||||||
|
const clearAudioState = (state: AudioState) => {
|
||||||
|
if (state.audioElement) {
|
||||||
|
state.audioElement.pause();
|
||||||
|
state.audioElement = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (state.audioUrl) {
|
||||||
|
URL.revokeObjectURL(state.audioUrl);
|
||||||
|
state.audioUrl = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
state.isPlaying = false;
|
||||||
|
state.audioChunks = [];
|
||||||
|
state.hasError = false;
|
||||||
|
state.errorMessage = "";
|
||||||
|
};
|
||||||
|
|
||||||
|
// 状态查询方法
|
||||||
|
const isPlaying = (messageId: string) => getAudioState(messageId).isPlaying;
|
||||||
|
const isLoading = (messageId: string) => getAudioState(messageId).isLoading;
|
||||||
|
const hasAudio = (messageId: string) =>
|
||||||
|
!!getAudioState(messageId).audioElement;
|
||||||
|
const hasError = (messageId: string) => getAudioState(messageId).hasError;
|
||||||
|
const getErrorMessage = (messageId: string) =>
|
||||||
|
getAudioState(messageId).errorMessage;
|
||||||
|
|
||||||
|
// 组件卸载时清理所有资源
|
||||||
|
onUnmounted(() => {
|
||||||
|
audioStates.value.forEach((state) => clearAudioState(state));
|
||||||
|
audioStates.value.clear();
|
||||||
|
});
|
||||||
|
|
||||||
|
return {
|
||||||
|
// 状态查询方法
|
||||||
|
isPlaying,
|
||||||
|
isLoading,
|
||||||
|
hasAudio,
|
||||||
|
hasError,
|
||||||
|
getErrorMessage,
|
||||||
|
|
||||||
|
// 核心方法
|
||||||
|
convertText,
|
||||||
|
handleAudioData,
|
||||||
|
finishConversion,
|
||||||
|
play,
|
||||||
|
pause,
|
||||||
|
pauseAll,
|
||||||
|
clearAudio,
|
||||||
|
handleError
|
||||||
|
};
|
||||||
|
});
|
||||||
20
web/src/utils/audio.ts
Normal file
20
web/src/utils/audio.ts
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
// 合并音频片段
|
||||||
|
export const mergeAudioChunks = (chunks: ArrayBuffer[]): Uint8Array => {
|
||||||
|
const totalLength = chunks.reduce((acc, chunk) => acc + chunk.byteLength, 0);
|
||||||
|
const merged = new Uint8Array(totalLength);
|
||||||
|
let offset = 0;
|
||||||
|
chunks.forEach((chunk) => {
|
||||||
|
merged.set(new Uint8Array(chunk), offset);
|
||||||
|
offset += chunk.byteLength;
|
||||||
|
});
|
||||||
|
return merged;
|
||||||
|
};
|
||||||
|
|
||||||
|
// 创建音频播放URL
|
||||||
|
export const createAudioUrl = (
|
||||||
|
audioData: Uint8Array,
|
||||||
|
mimeType = "audio/mp3"
|
||||||
|
): string => {
|
||||||
|
const blob = new Blob([audioData as BlobPart], { type: mimeType });
|
||||||
|
return URL.createObjectURL(blob);
|
||||||
|
};
|
||||||
7
web/src/utils/format.ts
Normal file
7
web/src/utils/format.ts
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
export const formatTime = (seconds: number): string => {
|
||||||
|
if (Number.isNaN(seconds) || !Number.isFinite(seconds)) return "00:00";
|
||||||
|
|
||||||
|
const minutes = Math.floor(seconds / 60);
|
||||||
|
const secs = Math.floor(seconds % 60);
|
||||||
|
return `${minutes.toString().padStart(2, "0")}:${secs.toString().padStart(2, "0")}`;
|
||||||
|
};
|
||||||
@@ -1,4 +1,6 @@
|
|||||||
|
export * from "./audio";
|
||||||
export * from "./context";
|
export * from "./context";
|
||||||
|
export * from "./format";
|
||||||
export * from "./media";
|
export * from "./media";
|
||||||
export * from "./pcm";
|
export * from "./pcm";
|
||||||
export * from "./title";
|
export * from "./title";
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import type { SelectGroupOption, SelectOption } from "naive-ui";
|
import type { SelectGroupOption, SelectOption } from "naive-ui";
|
||||||
|
import type { Message } from "@/interfaces";
|
||||||
import { throttle } from "lodash-es";
|
import { throttle } from "lodash-es";
|
||||||
import AIAvatar from "@/assets/ai_avatar.png";
|
import AIAvatar from "@/assets/ai_avatar.png";
|
||||||
import {
|
import {
|
||||||
@@ -25,17 +26,18 @@ const scrollbarRef = ref<HTMLElement | null>(null);
|
|||||||
const options = ref<Array<SelectGroupOption | SelectOption>>([]);
|
const options = ref<Array<SelectGroupOption | SelectOption>>([]);
|
||||||
// NCollapse 组件的折叠状态
|
// NCollapse 组件的折叠状态
|
||||||
const collapseActive = ref<string[]>(
|
const collapseActive = ref<string[]>(
|
||||||
historyMessages.value.map((_, idx) => String(idx))
|
historyMessages.value.map((msg, idx) => String(msg.id ?? idx))
|
||||||
);
|
);
|
||||||
|
|
||||||
const getName = (idx: number) => String(idx);
|
const getName = (msg: Message, idx: number) => String(msg.id ?? idx);
|
||||||
|
|
||||||
|
// TODO: bugfix: 未能正确展开
|
||||||
watch(
|
watch(
|
||||||
historyMessages,
|
historyMessages,
|
||||||
(newVal, oldVal) => {
|
(newVal, oldVal) => {
|
||||||
// 取所有name
|
// 取所有name
|
||||||
const newNames = newVal.map((_, idx) => getName(idx));
|
const newNames = newVal.map((msg, idx) => getName(msg, idx));
|
||||||
const oldNames = oldVal ? oldVal.map((_, idx) => getName(idx)) : [];
|
const oldNames = oldVal ? oldVal.map((msg, idx) => getName(msg, idx)) : [];
|
||||||
// 找出新增的name
|
// 找出新增的name
|
||||||
const addedNames = newNames.filter((name) => !oldNames.includes(name));
|
const addedNames = newNames.filter((name) => !oldNames.includes(name));
|
||||||
// 保留原有已展开项
|
// 保留原有已展开项
|
||||||
@@ -45,9 +47,10 @@ watch(
|
|||||||
// 新增的默认展开
|
// 新增的默认展开
|
||||||
collapseActive.value = [...currentActive, ...addedNames];
|
collapseActive.value = [...currentActive, ...addedNames];
|
||||||
},
|
},
|
||||||
{ deep: true }
|
{ immediate: true, deep: true }
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// 处理折叠项的点击事件,切换折叠状态
|
||||||
const handleItemHeaderClick = (name: string) => {
|
const handleItemHeaderClick = (name: string) => {
|
||||||
if (collapseActive.value.includes(name)) {
|
if (collapseActive.value.includes(name)) {
|
||||||
collapseActive.value = collapseActive.value.filter((n) => n !== name);
|
collapseActive.value = collapseActive.value.filter((n) => n !== name);
|
||||||
@@ -177,9 +180,15 @@ onMounted(() => {
|
|||||||
:expanded-names="collapseActive[idx]"
|
:expanded-names="collapseActive[idx]"
|
||||||
>
|
>
|
||||||
<NCollapseItem
|
<NCollapseItem
|
||||||
:title="thinking && idx === historyMessages.length - 1 ? '思考中...' : '已深度思考'"
|
:title="
|
||||||
:name="getName(idx)"
|
thinking && idx === historyMessages.length - 1
|
||||||
@item-header-click="() => handleItemHeaderClick(getName(idx))"
|
? '思考中...'
|
||||||
|
: '已深度思考'
|
||||||
|
"
|
||||||
|
:name="getName(msg, idx)"
|
||||||
|
@item-header-click="
|
||||||
|
() => handleItemHeaderClick(getName(msg, idx))
|
||||||
|
"
|
||||||
>
|
>
|
||||||
<div
|
<div
|
||||||
class="text-[#7A7A7A] mb-4 border-l-2 border-[#E5E5E5] ml-2 pl-2"
|
class="text-[#7A7A7A] mb-4 border-l-2 border-[#E5E5E5] ml-2 pl-2"
|
||||||
@@ -190,6 +199,9 @@ onMounted(() => {
|
|||||||
</NCollapse>
|
</NCollapse>
|
||||||
<!-- 内容↓ 思维链↑ -->
|
<!-- 内容↓ 思维链↑ -->
|
||||||
<markdown :content="msg.content || ''" />
|
<markdown :content="msg.content || ''" />
|
||||||
|
<div v-if="msg.role !== 'user'" class="mt-2">
|
||||||
|
<tts :text="msg.content || ''" :message-id="msg.id!" />
|
||||||
|
</div>
|
||||||
<NDivider />
|
<NDivider />
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
Reference in New Issue
Block a user