feat: tts语音生成

This commit is contained in:
2025-06-30 09:50:44 +08:00
parent 51e7239c71
commit 06e6b4a8c9
20 changed files with 1135 additions and 30 deletions

View File

@@ -0,0 +1,455 @@
# tts.py
import uuid
import websockets
import time
import fastrand
import json
import asyncio
from typing import Dict, Any, Optional as OptionalType
from app.constants.tts import APP_ID, TOKEN, SPEAKER
# 协议常量保持不变...
PROTOCOL_VERSION = 0b0001
DEFAULT_HEADER_SIZE = 0b0001
FULL_CLIENT_REQUEST = 0b0001
AUDIO_ONLY_RESPONSE = 0b1011
FULL_SERVER_RESPONSE = 0b1001
ERROR_INFORMATION = 0b1111
MsgTypeFlagWithEvent = 0b100
NO_SERIALIZATION = 0b0000
JSON = 0b0001
COMPRESSION_NO = 0b0000
# 事件类型
EVENT_NONE = 0
EVENT_Start_Connection = 1
EVENT_FinishConnection = 2
EVENT_ConnectionStarted = 50
EVENT_StartSession = 100
EVENT_FinishSession = 102
EVENT_SessionStarted = 150
EVENT_SessionFinished = 152
EVENT_TaskRequest = 200
EVENT_TTSSentenceEnd = 351
EVENT_TTSResponse = 352
# 所有类定义保持不变...
class Header:
def __init__(self,
protocol_version=PROTOCOL_VERSION,
header_size=DEFAULT_HEADER_SIZE,
message_type=0,
message_type_specific_flags=0,
serial_method=NO_SERIALIZATION,
compression_type=COMPRESSION_NO,
reserved_data=0):
self.header_size = header_size
self.protocol_version = protocol_version
self.message_type = message_type
self.message_type_specific_flags = message_type_specific_flags
self.serial_method = serial_method
self.compression_type = compression_type
self.reserved_data = reserved_data
def as_bytes(self) -> bytes:
return bytes([
(self.protocol_version << 4) | self.header_size,
(self.message_type << 4) | self.message_type_specific_flags,
(self.serial_method << 4) | self.compression_type,
self.reserved_data
])
class Optional:
def __init__(self, event=EVENT_NONE, sessionId=None, sequence=None):
self.event = event
self.sessionId = sessionId
self.errorCode = 0
self.connectionId = None
self.response_meta_json = None
self.sequence = sequence
def as_bytes(self) -> bytes:
option_bytes = bytearray()
if self.event != EVENT_NONE:
option_bytes.extend(self.event.to_bytes(4, "big", signed=True))
if self.sessionId is not None:
session_id_bytes = str.encode(self.sessionId)
size = len(session_id_bytes).to_bytes(4, "big", signed=True)
option_bytes.extend(size)
option_bytes.extend(session_id_bytes)
if self.sequence is not None:
option_bytes.extend(self.sequence.to_bytes(4, "big", signed=True))
return option_bytes
class Response:
def __init__(self, header, optional):
self.optional = optional
self.header = header
self.payload = None
self.payload_json = None
# 工具函数保持不变...
def gen_log_id():
"""生成logID"""
ts = int(time.time() * 1000)
r = fastrand.pcg32bounded(1 << 24) + (1 << 20)
local_ip = "00000000000000000000000000000000"
return f"02{ts}{local_ip}{r:08x}"
def get_payload_bytes(uid='1234', event=EVENT_NONE, text='', speaker='', audio_format='mp3',
audio_sample_rate=24000):
return str.encode(json.dumps({
"user": {"uid": uid},
"event": event,
"namespace": "BidirectionalTTS",
"req_params": {
"text": text,
"speaker": speaker,
"audio_params": {
"format": audio_format,
"sample_rate": audio_sample_rate,
"enable_timestamp": True,
}
}
}))
def read_res_content(res, offset):
content_size = int.from_bytes(res[offset: offset + 4], "big", signed=True)
offset += 4
content = str(res[offset: offset + content_size], encoding='utf8')
offset += content_size
return content, offset
def read_res_payload(res, offset):
payload_size = int.from_bytes(res[offset: offset + 4], "big", signed=True)
offset += 4
payload = res[offset: offset + payload_size]
offset += payload_size
return payload, offset
def parser_response(res) -> Response:
if isinstance(res, str):
raise RuntimeError(res)
response = Response(Header(), Optional())
# 解析结果
header = response.header
num = 0b00001111
header.protocol_version = res[0] >> 4 & num
header.header_size = res[0] & 0x0f
header.message_type = (res[1] >> 4) & num
header.message_type_specific_flags = res[1] & 0x0f
header.serial_method = res[2] >> num
header.message_compression = res[2] & 0x0f
header.reserved_data = res[3]
offset = 4
optional = response.optional
if header.message_type == FULL_SERVER_RESPONSE or AUDIO_ONLY_RESPONSE:
# read event
if header.message_type_specific_flags == MsgTypeFlagWithEvent:
optional.event = int.from_bytes(res[offset:offset + 4], "big", signed=True)
offset += 4
if optional.event == EVENT_NONE:
return response
# read connectionId
elif optional.event == EVENT_ConnectionStarted:
optional.connectionId, offset = read_res_content(res, offset)
elif optional.event == EVENT_SessionStarted or optional.event == EVENT_SessionFinished:
optional.sessionId, offset = read_res_content(res, offset)
optional.response_meta_json, offset = read_res_content(res, offset)
elif optional.event == EVENT_TTSResponse:
optional.sessionId, offset = read_res_content(res, offset)
response.payload, offset = read_res_payload(res, offset)
elif optional.event == EVENT_TTSSentenceEnd:
optional.sessionId, offset = read_res_content(res, offset)
response.payload_json, offset = read_res_content(res, offset)
elif header.message_type == ERROR_INFORMATION:
optional.errorCode = int.from_bytes(res[offset:offset + 4], "big", signed=True)
offset += 4
response.payload, offset = read_res_payload(res, offset)
return response
async def send_event(ws, header, optional=None, payload=None):
full_client_request = bytearray(header)
if optional is not None:
full_client_request.extend(optional)
if payload is not None:
payload_size = len(payload).to_bytes(4, 'big', signed=True)
full_client_request.extend(payload_size)
full_client_request.extend(payload)
await ws.send(full_client_request)
# 修改TTS状态管理类添加消息ID和任务追踪
class TTSState:
def __init__(self, message_id: str):
self.message_id = message_id
self.volc_ws: OptionalType[websockets.WebSocketServerProtocol] = None
self.session_id: OptionalType[str] = None
self.task: OptionalType[asyncio.Task] = None # 用于追踪异步任务
self.is_processing = False
# 全局状态管理
class TTSManager:
def __init__(self):
# WebSocket -> 消息ID -> TTS状态
self.connections: Dict[any, Dict[str, TTSState]] = {}
# 会话ID -> 消息ID 的映射,用于路由响应
self.session_to_message: Dict[str, str] = {}
# 消息ID -> WebSocket 的映射
self.message_to_websocket: Dict[str, any] = {}
def get_connection_states(self, websocket) -> Dict[str, TTSState]:
"""获取WebSocket连接的所有TTS状态"""
if websocket not in self.connections:
self.connections[websocket] = {}
return self.connections[websocket]
def add_tts_state(self, websocket, message_id: str) -> TTSState:
"""添加新的TTS状态"""
states = self.get_connection_states(websocket)
if message_id in states:
# 如果已存在,先清理旧的
self.cleanup_message_state(websocket, message_id)
tts_state = TTSState(message_id)
states[message_id] = tts_state
self.message_to_websocket[message_id] = websocket
return tts_state
def get_tts_state(self, websocket, message_id: str) -> OptionalType[TTSState]:
"""获取指定的TTS状态"""
states = self.get_connection_states(websocket)
return states.get(message_id)
def register_session(self, session_id: str, message_id: str):
"""注册会话ID和消息ID的映射"""
self.session_to_message[session_id] = message_id
def get_message_by_session(self, session_id: str) -> OptionalType[str]:
"""根据会话ID获取消息ID"""
return self.session_to_message.get(session_id)
def get_websocket_by_message(self, message_id: str):
"""根据消息ID获取WebSocket"""
return self.message_to_websocket.get(message_id)
def cleanup_message_state(self, websocket, message_id: str):
"""清理指定消息的状态"""
states = self.get_connection_states(websocket)
if message_id in states:
tts_state = states[message_id]
# 取消任务
if tts_state.task and not tts_state.task.done():
tts_state.task.cancel()
# 清理映射
if tts_state.session_id and tts_state.session_id in self.session_to_message:
del self.session_to_message[tts_state.session_id]
if message_id in self.message_to_websocket:
del self.message_to_websocket[message_id]
# 删除状态
del states[message_id]
def cleanup_connection(self, websocket):
"""清理整个连接的状态"""
if websocket in self.connections:
states = self.connections[websocket]
for message_id in list(states.keys()):
self.cleanup_message_state(websocket, message_id)
del self.connections[websocket]
# 全局TTS管理器实例
tts_manager = TTSManager()
# 初始化独立的TTS连接
async def create_tts_connection() -> websockets.WebSocketServerProtocol:
"""创建独立的TTS连接"""
log_id = gen_log_id()
ws_header = {
"X-Api-App-Key": APP_ID,
"X-Api-Access-Key": TOKEN,
"X-Api-Resource-Id": 'volc.service_type.10029',
"X-Api-Connect-Id": str(uuid.uuid4()),
"X-Tt-Logid": log_id,
}
url = 'wss://openspeech.bytedance.com/api/v3/tts/bidirection'
volc_ws = await websockets.connect(url, additional_headers=ws_header, max_size=1000000000)
# 启动连接
header = Header(message_type=FULL_CLIENT_REQUEST,
message_type_specific_flags=MsgTypeFlagWithEvent).as_bytes()
optional = Optional(event=EVENT_Start_Connection).as_bytes()
payload = str.encode("{}")
await send_event(volc_ws, header, optional, payload)
# 等待连接确认
raw_data = await volc_ws.recv()
res = parser_response(raw_data)
if res.optional.event != EVENT_ConnectionStarted:
raise Exception("TTS连接失败")
return volc_ws
# 处理单个TTS任务
async def process_tts_task(websocket, message_id: str, text: str):
"""处理单个TTS任务独立协程"""
tts_state = None
try:
print(f"开始处理TTS任务 [{message_id}]: {text}")
# 获取TTS状态
tts_state = tts_manager.get_tts_state(websocket, message_id)
if not tts_state:
raise Exception(f"找不到TTS状态: {message_id}")
tts_state.is_processing = True
# 创建独立的TTS连接
tts_state.volc_ws = await create_tts_connection()
# 创建会话
tts_state.session_id = uuid.uuid4().__str__().replace('-', '')
tts_manager.register_session(tts_state.session_id, message_id)
print(f"创建TTS会话 [{message_id}]: {tts_state.session_id}")
header = Header(message_type=FULL_CLIENT_REQUEST,
message_type_specific_flags=MsgTypeFlagWithEvent,
serial_method=JSON).as_bytes()
optional = Optional(event=EVENT_StartSession, sessionId=tts_state.session_id).as_bytes()
payload = get_payload_bytes(event=EVENT_StartSession, speaker=SPEAKER)
await send_event(tts_state.volc_ws, header, optional, payload)
raw_data = await tts_state.volc_ws.recv()
res = parser_response(raw_data)
if res.optional.event != EVENT_SessionStarted:
raise Exception("TTS会话启动失败")
print(f"TTS会话创建成功 [{message_id}]: {tts_state.session_id}")
# 发送文本到TTS服务
print(f"发送文本到TTS服务 [{message_id}]...")
header = Header(message_type=FULL_CLIENT_REQUEST,
message_type_specific_flags=MsgTypeFlagWithEvent,
serial_method=JSON).as_bytes()
optional = Optional(event=EVENT_TaskRequest, sessionId=tts_state.session_id).as_bytes()
payload = get_payload_bytes(event=EVENT_TaskRequest, text=text, speaker=SPEAKER)
await send_event(tts_state.volc_ws, header, optional, payload)
# 接收TTS响应并发送到前端
print(f"开始接收TTS响应 [{message_id}]...")
audio_count = 0
try:
while True:
raw_data = await asyncio.wait_for(
tts_state.volc_ws.recv(),
timeout=30
)
res = parser_response(raw_data)
print(f"收到TTS事件 [{message_id}]: {res.optional.event}")
if res.optional.event == EVENT_TTSSentenceEnd:
print(f"句子结束事件 [{message_id}] - 直接完成")
break
elif res.optional.event == EVENT_SessionFinished:
print(f"收到会话结束事件 [{message_id}]")
break
elif res.optional.event == EVENT_TTSResponse:
audio_count += 1
print(f"发送音频数据 [{message_id}] #{audio_count},大小: {len(res.payload)}")
# 发送音频数据包含消息ID
await websocket.send_json({
"type": "tts_audio_data",
"messageId": message_id,
"audioData": res.payload.hex() # 转为hex字符串
})
else:
print(f"未知TTS事件 [{message_id}]: {res.optional.event}")
except asyncio.TimeoutError:
print(f"TTS响应超时 [{message_id}],强制结束")
# 发送完成消息
await websocket.send_json({
"type": "tts_audio_complete",
"messageId": message_id
})
print(f"TTS处理完成 [{message_id}],共发送 {audio_count} 个音频包")
except asyncio.CancelledError:
print(f"TTS任务被取消 [{message_id}]")
await websocket.send_json({
"type": "tts_error",
"messageId": message_id,
"message": "TTS任务被取消"
})
except Exception as e:
print(f"TTS处理异常 [{message_id}]: {e}")
import traceback
traceback.print_exc()
await websocket.send_json({
"type": "tts_error",
"messageId": message_id,
"message": f"TTS处理失败: {str(e)}"
})
finally:
# 清理资源
if tts_state:
tts_state.is_processing = False
if tts_state.volc_ws:
try:
await tts_state.volc_ws.close()
except:
pass
# 清理状态
tts_manager.cleanup_message_state(websocket, message_id)
# 启动TTS文本转换
async def handle_tts_text(websocket, message_id: str, text: str):
"""启动TTS文本转换"""
# 创建新的TTS状态
tts_state = tts_manager.add_tts_state(websocket, message_id)
# 启动异步任务
tts_state.task = asyncio.create_task(
process_tts_task(websocket, message_id, text)
)
# 取消TTS任务
async def handle_tts_cancel(websocket, message_id: str):
"""取消TTS任务"""
tts_state = tts_manager.get_tts_state(websocket, message_id)
if tts_state and tts_state.task and not tts_state.task.done():
tts_state.task.cancel()
await websocket.send_json({
"type": "tts_complete",
"messageId": message_id
})
tts_manager.cleanup_message_state(websocket, message_id)
# 清理连接的所有TTS资源
async def cleanup_connection_tts(websocket):
"""清理连接的所有TTS资源"""
print(f"清理连接的TTS资源...")
tts_manager.cleanup_connection(websocket)
print("TTS资源清理完成")

View File

@@ -1,14 +1,19 @@
# websocket_service.py
from fastapi import APIRouter, WebSocket, WebSocketDisconnect from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from typing import Set from typing import Set
from aip import AipSpeech from aip import AipSpeech
from app.constants.asr import APP_ID, API_KEY, SECRET_KEY from app.constants.asr import APP_ID, API_KEY, SECRET_KEY
import json import json
# 导入修改后的TTS模块
from . import tts
router = APIRouter() router = APIRouter()
active_connections: Set[WebSocket] = set() active_connections: Set[WebSocket] = set()
asr_client = AipSpeech(APP_ID, API_KEY, SECRET_KEY) asr_client = AipSpeech(APP_ID, API_KEY, SECRET_KEY)
async def asr_buffer(buffer_data: bytes) -> str: async def asr_buffer(buffer_data: bytes) -> str:
result = asr_client.asr(buffer_data, 'pcm', 16000, {'dev_pid': 1537}) result = asr_client.asr(buffer_data, 'pcm', 16000, {'dev_pid': 1537})
if result.get('err_msg') == 'success.': if result.get('err_msg') == 'success.':
@@ -16,6 +21,7 @@ async def asr_buffer(buffer_data: bytes) -> str:
else: else:
return '语音转换失败' return '语音转换失败'
async def broadcast_online_count(): async def broadcast_online_count():
data = {"online_count": len(active_connections), 'type': 'count'} data = {"online_count": len(active_connections), 'type': 'count'}
to_remove = set() to_remove = set()
@@ -27,12 +33,14 @@ async def broadcast_online_count():
for ws in to_remove: for ws in to_remove:
active_connections.remove(ws) active_connections.remove(ws)
@router.websocket("/websocket") @router.websocket("/websocket")
async def websocket_online_count(websocket: WebSocket): async def websocket_online_count(websocket: WebSocket):
await websocket.accept() await websocket.accept()
active_connections.add(websocket) active_connections.add(websocket)
await broadcast_online_count() await broadcast_online_count()
temp_buffer = bytes() temp_buffer = bytes()
try: try:
while True: while True:
message = await websocket.receive() message = await websocket.receive()
@@ -45,15 +53,59 @@ async def websocket_online_count(websocket: WebSocket):
except Exception: except Exception:
continue continue
msg_type = data.get("type") msg_type = data.get("type")
if msg_type == "ping": if msg_type == "ping":
await websocket.send_json({"online_count": len(active_connections), "type": "count"}) await websocket.send_json({"online_count": len(active_connections), "type": "count"})
elif msg_type == "asr_end": elif msg_type == "asr_end":
asr_text = await asr_buffer(temp_buffer) asr_text = await asr_buffer(temp_buffer)
await websocket.send_json({"type": "asr_result", "result": asr_text}) await websocket.send_json({"type": "asr_result", "result": asr_text})
temp_buffer = bytes() temp_buffer = bytes()
# 修改TTS处理支持消息ID
elif msg_type == "tts_text":
message_id = data.get("messageId")
text = data.get("text", "")
if not message_id:
await websocket.send_json({
"type": "tts_error",
"message": "缺少messageId参数"
})
continue
print(f"收到TTS文本请求 [{message_id}]: {text}")
try:
await tts.handle_tts_text(websocket, message_id, text)
except Exception as e:
print(f"TTS文本处理异常 [{message_id}]: {e}")
await websocket.send_json({
"type": "tts_error",
"messageId": message_id,
"message": f"TTS处理失败: {str(e)}"
})
elif msg_type == "tts_cancel":
message_id = data.get("messageId")
if message_id:
print(f"收到TTS取消请求 [{message_id}]")
try:
await tts.handle_tts_cancel(websocket, message_id)
except Exception as e:
print(f"TTS取消处理异常 [{message_id}]: {e}")
except WebSocketDisconnect: except WebSocketDisconnect:
active_connections.remove(websocket) pass
await broadcast_online_count() except Exception as e:
except Exception: print(f"WebSocket异常: {e}")
active_connections.remove(websocket) finally:
# 清理资源
active_connections.discard(websocket)
# 清理所有TTS资源
try:
await tts.cleanup_connection_tts(websocket)
except:
pass
await broadcast_online_count() await broadcast_online_count()

View File

@@ -1,3 +1,7 @@
appId = '1142362958' # APP_ID = '1142362958'
token = 'O-a4JkyLFrYkME9no11DxFkOY-UnAoFF' # TOKEN = 'O-a4JkyLFrYkME9no11DxFkOY-UnAoFF'
speaker = 'zh_male_beijingxiaoye_moon_bigtts' # SPEAKER = 'zh_male_beijingxiaoye_moon_bigtts'
APP_ID = '2138450044'
TOKEN = 'V04_QumeQZhJrQ_In1Z0VBQm7n0ttMNO'
SPEAKER = 'zh_male_beijingxiaoye_moon_bigtts'

2
web/components.d.ts vendored
View File

@@ -18,10 +18,12 @@ declare module 'vue' {
NInput: typeof import('naive-ui')['NInput'] NInput: typeof import('naive-ui')['NInput']
NMessageProvider: typeof import('naive-ui')['NMessageProvider'] NMessageProvider: typeof import('naive-ui')['NMessageProvider']
NPopconfirm: typeof import('naive-ui')['NPopconfirm'] NPopconfirm: typeof import('naive-ui')['NPopconfirm']
NPopover: typeof import('naive-ui')['NPopover']
NScrollbar: typeof import('naive-ui')['NScrollbar'] NScrollbar: typeof import('naive-ui')['NScrollbar']
NSelect: typeof import('naive-ui')['NSelect'] NSelect: typeof import('naive-ui')['NSelect']
NTag: typeof import('naive-ui')['NTag'] NTag: typeof import('naive-ui')['NTag']
RouterLink: typeof import('vue-router')['RouterLink'] RouterLink: typeof import('vue-router')['RouterLink']
RouterView: typeof import('vue-router')['RouterView'] RouterView: typeof import('vue-router')['RouterView']
Tts: typeof import('./src/components/tts.vue')['default']
} }
} }

View File

@@ -31,6 +31,7 @@ export default antfu({
"antfu/top-level-function": "off", "antfu/top-level-function": "off",
"ts/no-unsafe-function-type": "off", "ts/no-unsafe-function-type": "off",
"no-console": "off", "no-console": "off",
"unused-imports/no-unused-vars": "warn" "unused-imports/no-unused-vars": "warn",
"ts/no-use-before-define": "off"
} }
}); });

View File

@@ -2,4 +2,5 @@ export { default as ChevronLeftIcon } from "./svg/heroicons/ChevronLeftIcon.svg?
export { default as ExclamationTriangleIcon } from "./svg/heroicons/ExclamationTriangleIcon.svg?component"; export { default as ExclamationTriangleIcon } from "./svg/heroicons/ExclamationTriangleIcon.svg?component";
export { default as microphone } from "./svg/heroicons/MicrophoneIcon.svg?component"; export { default as microphone } from "./svg/heroicons/MicrophoneIcon.svg?component";
export { default as PaperAirplaneIcon } from "./svg/heroicons/PaperAirplaneIcon.svg?component"; export { default as PaperAirplaneIcon } from "./svg/heroicons/PaperAirplaneIcon.svg?component";
export { default as SpeakerWaveIcon } from "./svg/heroicons/SpeakerWaveIcon.svg?component";
export { default as TrashIcon } from "./svg/heroicons/TrashIcon.svg?component"; export { default as TrashIcon } from "./svg/heroicons/TrashIcon.svg?component";

View File

@@ -0,0 +1,3 @@
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" class="size-6">
<path stroke-linecap="round" stroke-linejoin="round" d="M19.114 5.636a9 9 0 0 1 0 12.728M16.463 8.288a5.25 5.25 0 0 1 0 7.424M6.75 8.25l4.72-4.72a.75.75 0 0 1 1.28.53v15.88a.75.75 0 0 1-1.28.53l-4.72-4.72H4.51c-.88 0-1.704-.507-1.938-1.354A9.009 9.009 0 0 1 2.25 12c0-.83.112-1.633.322-2.396C2.806 8.756 3.63 8.25 4.51 8.25H6.75Z" />
</svg>

After

Width:  |  Height:  |  Size: 472 B

View File

@@ -0,0 +1,81 @@
<script setup lang="ts">
import { SpeakerWaveIcon } from "@/assets/Icons";
import { useLayoutStore, useTtsStore } from "@/stores";
const { text, messageId } = defineProps<{
text: string;
messageId: string;
}>();
const ttsStore = useTtsStore();
const layoutStore = useLayoutStore();
const { simpleMode } = storeToRefs(layoutStore);
// 获取当前消息的状态
const isPlaying = computed(() => ttsStore.isPlaying(messageId));
const isLoading = computed(() => ttsStore.isLoading(messageId));
const hasAudio = computed(() => ttsStore.hasAudio(messageId));
// 处理按钮点击
const handleClick = () => {
if (isLoading.value) {
return; // 合成中不响应点击
}
if (hasAudio.value) {
// 如果音频已准备好,切换播放/暂停
if (isPlaying.value) {
ttsStore.pause(messageId);
} else {
ttsStore.play(messageId);
}
} else {
// 如果没有音频开始TTS转换
ttsStore.convertText(text, messageId);
}
};
// 当文本改变时清理之前的音频
watch(
() => text,
() => {
ttsStore.clearAudio(messageId);
}
);
onUnmounted(() => {
ttsStore.clearAudio(messageId);
});
</script>
<template>
<NPopover trigger="hover">
<template #trigger>
<NButton
:loading="isLoading"
@click="handleClick"
quaternary
circle
:disabled="!text.trim()"
>
<SpeakerWaveIcon
v-if="!isLoading"
class="!w-4 !h-4"
:class="{
'': !simpleMode,
'animate-pulse': isPlaying
}"
/>
</NButton>
</template>
<span>
{{
isLoading
? "合成中..."
: isPlaying
? "点击暂停"
: hasAudio
? "点击播放"
: "语音合成"
}}
</span>
</NPopover>
</template>

View File

@@ -11,6 +11,7 @@ export interface Message {
thinking?: string; thinking?: string;
role?: string; role?: string;
usage?: UsageInfo; usage?: UsageInfo;
id?: string;
[property: string]: any; [property: string]: any;
} }

View File

@@ -16,7 +16,7 @@ const router = createRouter({
name: "community", name: "community",
component: community, component: community,
meta: { meta: {
title: "社区" title: "对话"
} }
} }
] ]

View File

@@ -0,0 +1,30 @@
import { useWebSocketStore } from "@/services";
export const useAudioWebSocket = () => {
const webSocketStore = useWebSocketStore();
const sendMessage = (data: string | Uint8Array) => {
if (webSocketStore.connected) {
if (typeof data === "string") {
webSocketStore.send(data);
} else {
webSocketStore.websocket?.send(data);
}
}
};
const ensureConnection = async (): Promise<void> => {
if (!webSocketStore.connected) {
webSocketStore.connect();
await new Promise<void>((resolve) => {
const check = () => {
if (webSocketStore.connected) resolve();
else setTimeout(check, 100);
};
check();
});
}
};
return { sendMessage, ensureConnection };
};

View File

@@ -1,3 +1,4 @@
export * from "./audio_websocket";
export * from "./base_service"; export * from "./base_service";
export * from "./chat_service"; export * from "./chat_service";
export * from "./websocket"; export * from "./websocket";

View File

@@ -1,28 +1,143 @@
import { useChatStore } from "@/stores"; import { useChatStore, useTtsStore } from "@/stores";
// WebSocket // WebSocket
export const useWebSocketStore = defineStore("websocket", () => { export const useWebSocketStore = defineStore("websocket", () => {
const websocket = ref<WebSocket>(); const websocket = ref<WebSocket>();
const connected = ref(false); const connected = ref(false);
const chatStore = useChatStore(); const chatStore = useChatStore();
const ttsStore = useTtsStore();
const { onlineCount } = storeToRefs(chatStore); const { onlineCount } = storeToRefs(chatStore);
const onmessage = (e: MessageEvent) => { const onmessage = (e: MessageEvent) => {
const data = JSON.parse(e.data); // 检查消息类型
switch (data.type) { if (e.data instanceof ArrayBuffer) {
case "count": // 处理二进制音频数据(兜底处理,新版本应该不会用到)
onlineCount.value = data.online_count; console.log("收到二进制音频数据,大小:", e.data.byteLength);
break; console.warn("收到旧格式的二进制数据无法确定messageId");
case "asr_result": // 可以选择忽略或者作为兜底处理
chatStore.addMessageToHistory(data.result); } else if (e.data instanceof Blob) {
// 如果是Blob转换为ArrayBuffer兜底处理
e.data.arrayBuffer().then((buffer: ArrayBuffer) => {
console.log("收到Blob音频数据大小:", buffer.byteLength);
console.warn("收到旧格式的Blob数据无法确定messageId");
});
} else if (typeof e.data === "string") {
// 处理文本JSON消息
try {
const data = JSON.parse(e.data);
switch (data.type) {
case "count":
onlineCount.value = data.online_count;
break;
case "asr_result":
chatStore.addMessageToHistory(data.result);
break;
// 新的TTS消息格式处理
case "tts_audio_data":
// 新的音频数据格式包含messageId和hex格式的音频数据
if (data.messageId && data.audioData) {
console.log(
`收到TTS音频数据 [${data.messageId}]hex长度:`,
data.audioData.length
);
try {
// 将hex字符串转换为ArrayBuffer
const bytes = data.audioData
.match(/.{1,2}/g)
?.map((byte: string) => Number.parseInt(byte, 16));
if (bytes) {
const buffer = new Uint8Array(bytes).buffer;
console.log(
`转换后的音频数据大小 [${data.messageId}]:`,
buffer.byteLength
);
ttsStore.handleAudioData(buffer, data.messageId);
} else {
console.error(`音频数据格式错误 [${data.messageId}]`);
}
} catch (error) {
console.error(`音频数据转换失败 [${data.messageId}]:`, error);
ttsStore.handleError(
`音频数据转换失败: ${error}`,
data.messageId
);
}
} else {
console.error("tts_audio_data消息格式错误:", data);
}
break;
case "tts_audio_complete":
// TTS音频传输完成
if (data.messageId) {
console.log(`TTS音频传输完成 [${data.messageId}]`);
ttsStore.finishConversion(data.messageId);
} else {
console.log("TTS音频传输完成无messageId");
// 兜底处理,可能是旧格式
ttsStore.finishConversion(data.messageId);
}
break;
case "tts_complete":
// TTS会话结束
if (data.messageId) {
console.log(`TTS会话结束 [${data.messageId}]`);
// 可以添加额外的清理逻辑
} else {
console.log("TTS会话结束");
}
break;
case "tts_error":
// TTS错误
if (data.messageId) {
console.error(`TTS错误 [${data.messageId}]:`, data.message);
ttsStore.handleError(data.message, data.messageId);
} else {
console.error("TTS错误:", data.message);
// 兜底处理,可能是旧格式
ttsStore.handleError(data.message, data.messageId || "unknown");
}
break;
// 保留旧的消息类型作为兜底处理
case "tts_audio_complete_legacy":
case "tts_complete_legacy":
case "tts_error_legacy":
console.log("收到旧格式TTS消息:", data.type);
// 可以选择处理或忽略
break;
default:
console.log("未知消息类型:", data.type, data);
}
} catch (error) {
console.error("JSON解析错误:", error, "原始数据:", e.data);
}
} else {
console.warn("收到未知格式的消息:", typeof e.data, e.data);
} }
}; };
const send = (data: string) => { const send = (data: string) => {
if (websocket.value && websocket.value.readyState === WebSocket.OPEN) if (websocket.value && websocket.value.readyState === WebSocket.OPEN) {
websocket.value?.send(data); websocket.value?.send(data);
} else {
console.warn("WebSocket未连接无法发送消息:", data);
}
}; };
const sendBinary = (data: ArrayBuffer | Uint8Array) => {
if (websocket.value && websocket.value.readyState === WebSocket.OPEN) {
websocket.value?.send(data);
} else {
console.warn("WebSocket未连接无法发送二进制数据");
}
};
const close = () => { const close = () => {
websocket.value?.close(); websocket.value?.close();
}; };
@@ -33,11 +148,15 @@ export const useWebSocketStore = defineStore("websocket", () => {
websocket.value.onopen = () => { websocket.value.onopen = () => {
connected.value = true; connected.value = true;
console.log("WebSocket连接成功");
let pingIntervalId: NodeJS.Timeout | undefined; let pingIntervalId: NodeJS.Timeout | undefined;
if (pingIntervalId) clearInterval(pingIntervalId); if (pingIntervalId) clearInterval(pingIntervalId);
pingIntervalId = setInterval(() => send("ping"), 30 * 1000); pingIntervalId = setInterval(() => {
// 修改ping格式为JSON格式与后端保持一致
send(JSON.stringify({ type: "ping" }));
}, 30 * 1000);
if (websocket.value) { if (websocket.value) {
websocket.value.onmessage = onmessage; websocket.value.onmessage = onmessage;
@@ -45,20 +164,28 @@ export const useWebSocketStore = defineStore("websocket", () => {
websocket.value.onerror = (e: Event) => { websocket.value.onerror = (e: Event) => {
console.error(`WebSocket错误:${(e as ErrorEvent).message}`); console.error(`WebSocket错误:${(e as ErrorEvent).message}`);
}; };
websocket.value.onclose = () => {
websocket.value.onclose = (e: CloseEvent) => {
connected.value = false; connected.value = false;
console.log(`WebSocket连接关闭: ${e.code} ${e.reason}`);
setTimeout(() => { setTimeout(() => {
console.log("尝试重新连接WebSocket...");
connect(); // 尝试重新连接 connect(); // 尝试重新连接
}, 1000); // 1秒后重试连接 }, 1000); // 1秒后重试连接
}; };
} }
}; };
websocket.value.onerror = (e: Event) => {
console.error("WebSocket连接错误:", e);
};
}; };
return { return {
websocket, websocket,
connected, connected,
send, send,
sendBinary,
close, close,
connect connect
}; };

View File

@@ -118,7 +118,10 @@ export const useChatStore = defineStore("chat", () => {
historyMessages.value[historyMessages.value.length - 1].thinking = historyMessages.value[historyMessages.value.length - 1].thinking =
thinkingContent; thinkingContent;
} }
); ).then(() => {
historyMessages.value[historyMessages.value.length - 1].id =
new Date().getTime().toString();
});
} }
} }
}, },

View File

@@ -1,3 +1,4 @@
export * from "./asr_store"; export * from "./asr_store";
export * from "./chat_store"; export * from "./chat_store";
export * from "./layout_store"; export * from "./layout_store";
export * from "./tts_store";

302
web/src/stores/tts_store.ts Normal file
View File

@@ -0,0 +1,302 @@
import { useAudioWebSocket } from "@/services";
import { createAudioUrl, mergeAudioChunks } from "@/utils";
interface AudioState {
isPlaying: boolean;
isLoading: boolean;
audioElement: HTMLAudioElement | null;
audioUrl: string | null;
audioChunks: ArrayBuffer[];
hasError: boolean;
errorMessage: string;
}
export const useTtsStore = defineStore("tts", () => {
// 多音频状态管理 - 以消息ID为key
const audioStates = ref<Map<string, AudioState>>(new Map());
// 当前活跃的转换请求(保留用于兼容性)
const activeConversion = ref<string | null>(null);
// 会话状态
const hasActiveSession = ref(false);
// WebSocket连接
const { sendMessage, ensureConnection } = useAudioWebSocket();
/**
* 获取或创建音频状态
*/
const getAudioState = (messageId: string): AudioState => {
if (!audioStates.value.has(messageId)) {
audioStates.value.set(messageId, {
isPlaying: false,
isLoading: false,
audioElement: null,
audioUrl: null,
audioChunks: [],
hasError: false,
errorMessage: ""
});
}
return audioStates.value.get(messageId)!;
};
/**
* 发送文本进行TTS转换
*/
const convertText = async (text: string, messageId: string) => {
try {
await ensureConnection();
// 暂停其他正在播放的音频
pauseAll();
// 获取当前消息的状态
const state = getAudioState(messageId);
// 清理之前的音频和错误状态
clearAudioState(state);
state.isLoading = true;
state.audioChunks = [];
// 设置当前活跃转换
activeConversion.value = messageId;
hasActiveSession.value = true;
// 发送文本到TTS服务
sendMessage(JSON.stringify({ type: "tts_text", text, messageId }));
} catch (error) {
handleError(`连接失败: ${error}`, messageId);
}
};
/**
* 处理接收到的音频数据 - 修改为支持messageId参数
*/
const handleAudioData = (data: ArrayBuffer, messageId?: string) => {
// 如果传递了messageId就使用它否则使用activeConversion
const targetMessageId = messageId || activeConversion.value;
if (!targetMessageId) {
console.warn("handleAudioData: 没有有效的messageId");
return;
}
console.log(`接收音频数据 [${targetMessageId}],大小:`, data.byteLength);
const state = getAudioState(targetMessageId);
state.audioChunks.push(data);
};
/**
* 完成TTS转换创建播放器并自动播放 - 修改为支持messageId参数
*/
const finishConversion = async (messageId?: string) => {
// 如果传递了messageId就使用它否则使用activeConversion
const targetMessageId = messageId || activeConversion.value;
if (!targetMessageId) {
console.warn("finishConversion: 没有有效的messageId");
return;
}
const state = getAudioState(targetMessageId);
console.log(
`完成TTS转换 [${targetMessageId}],音频片段数量:`,
state.audioChunks.length
);
if (state.audioChunks.length === 0) {
handleError("没有接收到音频数据", targetMessageId);
return;
}
try {
// 合并音频片段
const mergedAudio = mergeAudioChunks(state.audioChunks);
console.log(
`合并后音频大小 [${targetMessageId}]:`,
mergedAudio.byteLength
);
// 创建音频URL和元素
state.audioUrl = createAudioUrl(mergedAudio);
state.audioElement = new Audio(state.audioUrl);
// 设置音频事件
setupAudioEvents(state, targetMessageId);
state.isLoading = false;
// 清除activeConversion如果是当前活跃的
if (activeConversion.value === targetMessageId) {
activeConversion.value = null;
}
console.log(`TTS音频准备完成 [${targetMessageId}],开始自动播放`);
// 自动播放
await play(targetMessageId);
} catch (error) {
handleError(`音频处理失败: ${error}`, targetMessageId);
}
};
/**
* 设置音频事件监听
*/
const setupAudioEvents = (state: AudioState, messageId: string) => {
if (!state.audioElement) return;
const audio = state.audioElement;
audio.addEventListener("ended", () => {
state.isPlaying = false;
console.log(`音频播放结束 [${messageId}]`);
});
audio.addEventListener("error", (e) => {
console.error(`音频播放错误 [${messageId}]:`, e);
handleError("音频播放失败", messageId);
});
audio.addEventListener("canplaythrough", () => {
console.log(`音频可以播放 [${messageId}]`);
});
};
/**
* 播放指定消息的音频
*/
const play = async (messageId: string) => {
const state = getAudioState(messageId);
if (!state.audioElement) {
handleError("音频未准备好", messageId);
return;
}
try {
// 暂停其他正在播放的音频
pauseAll(messageId);
await state.audioElement.play();
state.isPlaying = true;
state.hasError = false;
state.errorMessage = "";
console.log(`开始播放音频 [${messageId}]`);
} catch (error) {
handleError(`播放失败: ${error}`, messageId);
}
};
/**
* 暂停指定消息的音频
*/
const pause = (messageId: string) => {
const state = getAudioState(messageId);
if (!state.audioElement) return;
state.audioElement.pause();
state.isPlaying = false;
console.log(`暂停音频 [${messageId}]`);
};
/**
* 暂停所有音频
*/
const pauseAll = (excludeMessageId?: string) => {
audioStates.value.forEach((state, messageId) => {
if (excludeMessageId && messageId === excludeMessageId) return;
if (state.isPlaying && state.audioElement) {
state.audioElement.pause();
state.isPlaying = false;
}
});
};
/**
* 处理TTS错误 - 修改为支持messageId参数
*/
const handleError = (errorMsg: string, messageId?: string) => {
// 如果传递了messageId就使用它否则使用activeConversion
const targetMessageId = messageId || activeConversion.value;
if (!targetMessageId) {
console.error(`TTS错误 (无messageId): ${errorMsg}`);
return;
}
console.error(`TTS错误 [${targetMessageId}]: ${errorMsg}`);
const state = getAudioState(targetMessageId);
state.hasError = true;
state.errorMessage = errorMsg;
state.isLoading = false;
if (activeConversion.value === targetMessageId) {
activeConversion.value = null;
hasActiveSession.value = false;
}
};
/**
* 清理指定消息的音频资源
*/
const clearAudio = (messageId: string) => {
const state = getAudioState(messageId);
clearAudioState(state);
audioStates.value.delete(messageId);
};
/**
* 清理音频状态
*/
const clearAudioState = (state: AudioState) => {
if (state.audioElement) {
state.audioElement.pause();
state.audioElement = null;
}
if (state.audioUrl) {
URL.revokeObjectURL(state.audioUrl);
state.audioUrl = null;
}
state.isPlaying = false;
state.audioChunks = [];
state.hasError = false;
state.errorMessage = "";
};
// 状态查询方法
const isPlaying = (messageId: string) => getAudioState(messageId).isPlaying;
const isLoading = (messageId: string) => getAudioState(messageId).isLoading;
const hasAudio = (messageId: string) =>
!!getAudioState(messageId).audioElement;
const hasError = (messageId: string) => getAudioState(messageId).hasError;
const getErrorMessage = (messageId: string) =>
getAudioState(messageId).errorMessage;
// 组件卸载时清理所有资源
onUnmounted(() => {
audioStates.value.forEach((state) => clearAudioState(state));
audioStates.value.clear();
});
return {
// 状态查询方法
isPlaying,
isLoading,
hasAudio,
hasError,
getErrorMessage,
// 核心方法
convertText,
handleAudioData,
finishConversion,
play,
pause,
pauseAll,
clearAudio,
handleError
};
});

20
web/src/utils/audio.ts Normal file
View File

@@ -0,0 +1,20 @@
// 合并音频片段
export const mergeAudioChunks = (chunks: ArrayBuffer[]): Uint8Array => {
const totalLength = chunks.reduce((acc, chunk) => acc + chunk.byteLength, 0);
const merged = new Uint8Array(totalLength);
let offset = 0;
chunks.forEach((chunk) => {
merged.set(new Uint8Array(chunk), offset);
offset += chunk.byteLength;
});
return merged;
};
// 创建音频播放URL
export const createAudioUrl = (
audioData: Uint8Array,
mimeType = "audio/mp3"
): string => {
const blob = new Blob([audioData as BlobPart], { type: mimeType });
return URL.createObjectURL(blob);
};

7
web/src/utils/format.ts Normal file
View File

@@ -0,0 +1,7 @@
export const formatTime = (seconds: number): string => {
if (Number.isNaN(seconds) || !Number.isFinite(seconds)) return "00:00";
const minutes = Math.floor(seconds / 60);
const secs = Math.floor(seconds % 60);
return `${minutes.toString().padStart(2, "0")}:${secs.toString().padStart(2, "0")}`;
};

View File

@@ -1,4 +1,6 @@
export * from "./audio";
export * from "./context"; export * from "./context";
export * from "./format";
export * from "./media"; export * from "./media";
export * from "./pcm"; export * from "./pcm";
export * from "./title"; export * from "./title";

View File

@@ -1,5 +1,6 @@
<script setup lang="ts"> <script setup lang="ts">
import type { SelectGroupOption, SelectOption } from "naive-ui"; import type { SelectGroupOption, SelectOption } from "naive-ui";
import type { Message } from "@/interfaces";
import { throttle } from "lodash-es"; import { throttle } from "lodash-es";
import AIAvatar from "@/assets/ai_avatar.png"; import AIAvatar from "@/assets/ai_avatar.png";
import { import {
@@ -25,17 +26,18 @@ const scrollbarRef = ref<HTMLElement | null>(null);
const options = ref<Array<SelectGroupOption | SelectOption>>([]); const options = ref<Array<SelectGroupOption | SelectOption>>([]);
// NCollapse 组件的折叠状态 // NCollapse 组件的折叠状态
const collapseActive = ref<string[]>( const collapseActive = ref<string[]>(
historyMessages.value.map((_, idx) => String(idx)) historyMessages.value.map((msg, idx) => String(msg.id ?? idx))
); );
const getName = (idx: number) => String(idx); const getName = (msg: Message, idx: number) => String(msg.id ?? idx);
// TODO: bugfix: 未能正确展开
watch( watch(
historyMessages, historyMessages,
(newVal, oldVal) => { (newVal, oldVal) => {
// 取所有name // 取所有name
const newNames = newVal.map((_, idx) => getName(idx)); const newNames = newVal.map((msg, idx) => getName(msg, idx));
const oldNames = oldVal ? oldVal.map((_, idx) => getName(idx)) : []; const oldNames = oldVal ? oldVal.map((msg, idx) => getName(msg, idx)) : [];
// 找出新增的name // 找出新增的name
const addedNames = newNames.filter((name) => !oldNames.includes(name)); const addedNames = newNames.filter((name) => !oldNames.includes(name));
// 保留原有已展开项 // 保留原有已展开项
@@ -45,9 +47,10 @@ watch(
// 新增的默认展开 // 新增的默认展开
collapseActive.value = [...currentActive, ...addedNames]; collapseActive.value = [...currentActive, ...addedNames];
}, },
{ deep: true } { immediate: true, deep: true }
); );
// 处理折叠项的点击事件,切换折叠状态
const handleItemHeaderClick = (name: string) => { const handleItemHeaderClick = (name: string) => {
if (collapseActive.value.includes(name)) { if (collapseActive.value.includes(name)) {
collapseActive.value = collapseActive.value.filter((n) => n !== name); collapseActive.value = collapseActive.value.filter((n) => n !== name);
@@ -177,9 +180,15 @@ onMounted(() => {
:expanded-names="collapseActive[idx]" :expanded-names="collapseActive[idx]"
> >
<NCollapseItem <NCollapseItem
:title="thinking && idx === historyMessages.length - 1 ? '思考中...' : '已深度思考'" :title="
:name="getName(idx)" thinking && idx === historyMessages.length - 1
@item-header-click="() => handleItemHeaderClick(getName(idx))" ? '思考中...'
: '已深度思考'
"
:name="getName(msg, idx)"
@item-header-click="
() => handleItemHeaderClick(getName(msg, idx))
"
> >
<div <div
class="text-[#7A7A7A] mb-4 border-l-2 border-[#E5E5E5] ml-2 pl-2" class="text-[#7A7A7A] mb-4 border-l-2 border-[#E5E5E5] ml-2 pl-2"
@@ -190,6 +199,9 @@ onMounted(() => {
</NCollapse> </NCollapse>
<!-- 内容↓ 思维链↑ --> <!-- 内容↓ 思维链↑ -->
<markdown :content="msg.content || ''" /> <markdown :content="msg.content || ''" />
<div v-if="msg.role !== 'user'" class="mt-2">
<tts :text="msg.content || ''" :message-id="msg.id!" />
</div>
<NDivider /> <NDivider />
</div> </div>
</div> </div>