feat: 支持音色切换
This commit is contained in:
10
backend/app/api/v1/endpoints/speaker.py
Normal file
10
backend/app/api/v1/endpoints/speaker.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from fastapi import APIRouter
|
||||
from app.constants.tts import SPEAKER_DATA
|
||||
from app.schemas import SpeakerResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/list", response_model=SpeakerResponse)
|
||||
async def get_model_vendors():
|
||||
return SpeakerResponse(data=SPEAKER_DATA)
|
||||
@@ -340,55 +340,51 @@ async def save_audio_file(audio_data: bytes, filename: str) -> str:
|
||||
|
||||
|
||||
# 处理单个TTS任务
|
||||
async def process_tts_task(websocket, message_id: str, text: str):
|
||||
async def process_tts_task(websocket, message_id: str, text: str, speaker: str = None):
|
||||
"""处理单个TTS任务(独立协程)"""
|
||||
tts_state = None
|
||||
try:
|
||||
print(f"开始处理TTS任务 [{message_id}]: {text}")
|
||||
# 使用传入的speaker,如果没有则使用默认的
|
||||
selected_speaker = speaker if speaker else SPEAKER
|
||||
|
||||
try:
|
||||
print(f"开始处理TTS任务 [{message_id}]: {text}, 使用说话人: {selected_speaker}")
|
||||
# 获取TTS状态
|
||||
tts_state = tts_manager.get_tts_state(websocket, message_id)
|
||||
if not tts_state:
|
||||
raise Exception(f"找不到TTS状态: {message_id}")
|
||||
|
||||
tts_state.is_processing = True
|
||||
# 生成音频文件名
|
||||
tts_state.audio_filename = generate_audio_filename()
|
||||
|
||||
# 创建独立的TTS连接
|
||||
tts_state.volc_ws = await create_tts_connection()
|
||||
|
||||
# 创建会话
|
||||
tts_state.session_id = uuid.uuid4().__str__().replace('-', '')
|
||||
tts_manager.register_session(tts_state.session_id, message_id)
|
||||
|
||||
print(f"创建TTS会话 [{message_id}]: {tts_state.session_id}")
|
||||
header = Header(message_type=FULL_CLIENT_REQUEST,
|
||||
message_type_specific_flags=MsgTypeFlagWithEvent,
|
||||
serial_method=JSON).as_bytes()
|
||||
optional = Optional(event=EVENT_StartSession, sessionId=tts_state.session_id).as_bytes()
|
||||
payload = get_payload_bytes(event=EVENT_StartSession, speaker=SPEAKER)
|
||||
# 使用选择的speaker
|
||||
payload = get_payload_bytes(event=EVENT_StartSession, speaker=selected_speaker)
|
||||
await send_event(tts_state.volc_ws, header, optional, payload)
|
||||
|
||||
raw_data = await tts_state.volc_ws.recv()
|
||||
res = parser_response(raw_data)
|
||||
if res.optional.event != EVENT_SessionStarted:
|
||||
raise Exception("TTS会话启动失败")
|
||||
print(f"TTS会话创建成功 [{message_id}]: {tts_state.session_id}")
|
||||
|
||||
# 发送文本到TTS服务
|
||||
print(f"发送文本到TTS服务 [{message_id}]...")
|
||||
header = Header(message_type=FULL_CLIENT_REQUEST,
|
||||
message_type_specific_flags=MsgTypeFlagWithEvent,
|
||||
serial_method=JSON).as_bytes()
|
||||
optional = Optional(event=EVENT_TaskRequest, sessionId=tts_state.session_id).as_bytes()
|
||||
payload = get_payload_bytes(event=EVENT_TaskRequest, text=text, speaker=SPEAKER)
|
||||
# 使用选择的speaker
|
||||
payload = get_payload_bytes(event=EVENT_TaskRequest, text=text, speaker=selected_speaker)
|
||||
await send_event(tts_state.volc_ws, header, optional, payload)
|
||||
|
||||
# 接收TTS响应并发送到前端
|
||||
print(f"开始接收TTS响应 [{message_id}]...")
|
||||
audio_count = 0
|
||||
|
||||
try:
|
||||
while True:
|
||||
raw_data = await asyncio.wait_for(
|
||||
@@ -396,17 +392,13 @@ async def process_tts_task(websocket, message_id: str, text: str):
|
||||
timeout=30
|
||||
)
|
||||
res = parser_response(raw_data)
|
||||
|
||||
print(f"收到TTS事件 [{message_id}]: {res.optional.event}")
|
||||
|
||||
if res.optional.event == EVENT_TTSSentenceEnd:
|
||||
print(f"句子结束事件 [{message_id}] - 直接完成")
|
||||
break
|
||||
|
||||
elif res.optional.event == EVENT_SessionFinished:
|
||||
print(f"收到会话结束事件 [{message_id}]")
|
||||
break
|
||||
|
||||
elif res.optional.event == EVENT_TTSResponse:
|
||||
audio_count += 1
|
||||
print(f"收到音频数据 [{message_id}] #{audio_count},大小: {len(res.payload)}")
|
||||
@@ -423,10 +415,8 @@ async def process_tts_task(websocket, message_id: str, text: str):
|
||||
})
|
||||
else:
|
||||
print(f"未知TTS事件 [{message_id}]: {res.optional.event}")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
print(f"TTS响应超时 [{message_id}],强制结束")
|
||||
|
||||
# 异步保存音频文件
|
||||
if tts_state.audio_data:
|
||||
file_path = await save_audio_file(
|
||||
@@ -435,15 +425,15 @@ async def process_tts_task(websocket, message_id: str, text: str):
|
||||
)
|
||||
print(f"音频文件已保存 [{message_id}]: {file_path}")
|
||||
|
||||
# 发送完成消息,包含文件路径
|
||||
# 发送完成消息,包含文件路径和使用的speaker
|
||||
await websocket.send_json({
|
||||
"type": "tts_audio_complete",
|
||||
"messageId": message_id,
|
||||
"audioFile": tts_state.audio_filename,
|
||||
"audioPath": os.path.join(TEMP_AUDIO_DIR, tts_state.audio_filename) if tts_state.audio_data else None
|
||||
"audioPath": os.path.join(TEMP_AUDIO_DIR, tts_state.audio_filename) if tts_state.audio_data else None,
|
||||
"speaker": selected_speaker
|
||||
})
|
||||
print(f"TTS处理完成 [{message_id}],共发送 {audio_count} 个音频包")
|
||||
|
||||
print(f"TTS处理完成 [{message_id}],共发送 {audio_count} 个音频包,使用说话人: {selected_speaker}")
|
||||
except asyncio.CancelledError:
|
||||
print(f"TTS任务被取消 [{message_id}]")
|
||||
await websocket.send_json({
|
||||
@@ -474,14 +464,14 @@ async def process_tts_task(websocket, message_id: str, text: str):
|
||||
|
||||
|
||||
# 启动TTS文本转换
|
||||
async def handle_tts_text(websocket, message_id: str, text: str):
|
||||
async def handle_tts_text(websocket, message_id: str, text: str, speaker: str = None):
|
||||
"""启动TTS文本转换"""
|
||||
# 创建新的TTS状态
|
||||
print(speaker)
|
||||
tts_state = tts_manager.add_tts_state(websocket, message_id)
|
||||
|
||||
# 启动异步任务
|
||||
# 启动异步任务,传入speaker参数
|
||||
tts_state.task = asyncio.create_task(
|
||||
process_tts_task(websocket, message_id, text)
|
||||
process_tts_task(websocket, message_id, text, speaker)
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from . import tts
|
||||
from app.constants.model_data import tip_message, base_url, headers
|
||||
|
||||
|
||||
async def process_voice_conversation(websocket: WebSocket, asr_text: str, message_id: str):
|
||||
async def process_voice_conversation(websocket: WebSocket, asr_text: str, message_id: str, speaker: str):
|
||||
try:
|
||||
print(f"开始处理语音对话 [{message_id}]: {asr_text}")
|
||||
|
||||
@@ -92,7 +92,7 @@ async def process_voice_conversation(websocket: WebSocket, asr_text: str, messag
|
||||
|
||||
# 启动TTS处理完整内容
|
||||
print(f"启动完整TTS处理 [{message_id}]: {full_response}")
|
||||
await tts.handle_tts_text(websocket, message_id, full_response)
|
||||
await tts.handle_tts_text(websocket, message_id, full_response, speaker)
|
||||
|
||||
except Exception as e:
|
||||
print(f"语音对话处理异常 [{message_id}]: {e}")
|
||||
|
||||
@@ -64,7 +64,8 @@ async def websocket_online_count(websocket: WebSocket):
|
||||
# 从data中获取messageId,如果不存在则生成一个新的ID
|
||||
message_id = data.get("messageId", "voice_" + str(uuid.uuid4()))
|
||||
if data.get("voiceConversation"):
|
||||
await process_voice_conversation(websocket, asr_text, message_id)
|
||||
speaker = data.get("speaker")
|
||||
await process_voice_conversation(websocket, asr_text, message_id, speaker)
|
||||
else:
|
||||
await websocket.send_json({"type": "asr_result", "result": asr_text})
|
||||
temp_buffer = bytes()
|
||||
@@ -73,6 +74,7 @@ async def websocket_online_count(websocket: WebSocket):
|
||||
elif msg_type == "tts_text":
|
||||
message_id = data.get("messageId")
|
||||
text = data.get("text", "")
|
||||
speaker = data.get("speaker")
|
||||
|
||||
if not message_id:
|
||||
await websocket.send_json({
|
||||
@@ -83,7 +85,7 @@ async def websocket_online_count(websocket: WebSocket):
|
||||
|
||||
print(f"收到TTS文本请求 [{message_id}]: {text}")
|
||||
try:
|
||||
await tts.handle_tts_text(websocket, message_id, text)
|
||||
await tts.handle_tts_text(websocket, message_id, text, speaker)
|
||||
except Exception as e:
|
||||
print(f"TTS文本处理异常 [{message_id}]: {e}")
|
||||
await websocket.send_json({
|
||||
|
||||
@@ -4,4 +4,101 @@
|
||||
|
||||
APP_ID = '2138450044'
|
||||
TOKEN = 'V04_QumeQZhJrQ_In1Z0VBQm7n0ttMNO'
|
||||
SPEAKER = 'zh_male_beijingxiaoye_moon_bigtts'
|
||||
SPEAKER = 'zh_male_beijingxiaoye_moon_bigtts'
|
||||
|
||||
SPEAKER_DATA = [
|
||||
{
|
||||
"category": "趣味口音",
|
||||
"speakers": [
|
||||
{
|
||||
"speaker_id": "zh_male_jingqiangkanye_moon_bigtts",
|
||||
"speaker_name": "京腔侃爷/Harmony",
|
||||
"language": "中文-北京口音、英文",
|
||||
"platforms": ["豆包", "Cici", "web demo"]
|
||||
},
|
||||
{
|
||||
"speaker_id": "zh_female_wanwanxiaohe_moon_bigtts",
|
||||
"speaker_name": "湾湾小何",
|
||||
"language": "中文-台湾口音",
|
||||
"platforms": ["豆包", "Cici"]
|
||||
},
|
||||
{
|
||||
"speaker_id": "zh_female_wanqudashu_moon_bigtts",
|
||||
"speaker_name": "湾区大叔",
|
||||
"language": "中文-广东口音",
|
||||
"platforms": ["豆包", "Cici"]
|
||||
},
|
||||
{
|
||||
"speaker_id": "zh_female_daimengchuanmei_moon_bigtts",
|
||||
"speaker_name": "呆萌川妹",
|
||||
"language": "中文-四川口音",
|
||||
"platforms": ["豆包", "Cici"]
|
||||
},
|
||||
{
|
||||
"speaker_id": "zh_male_guozhoudege_moon_bigtts",
|
||||
"speaker_name": "广州德哥",
|
||||
"language": "中文-广东口音",
|
||||
"platforms": ["豆包", "Cici"]
|
||||
},
|
||||
{
|
||||
"speaker_id": "zh_male_beijingxiaoye_moon_bigtts",
|
||||
"speaker_name": "北京小爷",
|
||||
"language": "中文-北京口音",
|
||||
"platforms": ["豆包"]
|
||||
},
|
||||
{
|
||||
"speaker_id": "zh_male_haoyuxiaoge_moon_bigtts",
|
||||
"speaker_name": "浩宇小哥",
|
||||
"language": "中文-青岛口音",
|
||||
"platforms": ["豆包"]
|
||||
},
|
||||
{
|
||||
"speaker_id": "zh_male_guangxiyuanzhou_moon_bigtts",
|
||||
"speaker_name": "广西远舟",
|
||||
"language": "中文-广西口音",
|
||||
"platforms": ["豆包"]
|
||||
},
|
||||
{
|
||||
"speaker_id": "zh_female_meituojieer_moon_bigtts",
|
||||
"speaker_name": "妹坨洁儿",
|
||||
"language": "中文-长沙口音",
|
||||
"platforms": ["豆包", "剪映"]
|
||||
},
|
||||
{
|
||||
"speaker_id": "zh_male_yuzhouzixuan_moon_bigtts",
|
||||
"speaker_name": "豫州子轩",
|
||||
"language": "中文-河南口音",
|
||||
"platforms": ["豆包"]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"category": "角色扮演",
|
||||
"speakers": [
|
||||
{
|
||||
"speaker_id": "zh_male_naiqimengwa_mars_bigtts",
|
||||
"speaker_name": "奶气萌娃",
|
||||
"language": "中文",
|
||||
"platforms": ["剪映", "豆包"]
|
||||
},
|
||||
{
|
||||
"speaker_id": "zh_female_popo_mars_bigtts",
|
||||
"speaker_name": "婆婆",
|
||||
"language": "中文",
|
||||
"platforms": ["剪映C端", "抖音", "豆包"]
|
||||
},
|
||||
{
|
||||
"speaker_id": "zh_female_gaolengyujie_moon_bigtts",
|
||||
"speaker_name": "高冷御姐",
|
||||
"language": "中文",
|
||||
"platforms": ["豆包", "Cici"]
|
||||
},
|
||||
{
|
||||
"speaker_id": "zh_male_aojiaobazong_moon_bigtts",
|
||||
"speaker_name": "傲娇霸总",
|
||||
"language": "中文",
|
||||
"platforms": ["豆包"]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from fastapi import FastAPI
|
||||
from app.api.v1.endpoints import chat, model, websocket_service
|
||||
from app.api.v1.endpoints import chat, model, websocket_service,speaker
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@@ -9,6 +9,7 @@ app.include_router(websocket_service.router, prefix="", tags=["websocket_service
|
||||
app.include_router(chat.router, prefix="/v1/chat", tags=["chat"])
|
||||
# 获取模型列表服务
|
||||
app.include_router(model.router, prefix="/v1/model", tags=["model_list"])
|
||||
app.include_router(speaker.router, prefix="/v1/speaker", tags=["speaker_list"])
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
@@ -5,4 +5,5 @@ from .chat import (
|
||||
ModelInfo,
|
||||
VendorModelList,
|
||||
VendorModelResponse,
|
||||
SpeakerResponse
|
||||
)
|
||||
|
||||
@@ -33,3 +33,20 @@ class VendorModelList(BaseModel):
|
||||
|
||||
class VendorModelResponse(BaseModel):
|
||||
data: List[VendorModelList]
|
||||
|
||||
|
||||
# Speaker相关模型
|
||||
class Speaker(BaseModel):
|
||||
speaker_id: str
|
||||
speaker_name: str
|
||||
language: str
|
||||
platforms: List[str]
|
||||
|
||||
|
||||
class CategorySpeakers(BaseModel):
|
||||
category: str
|
||||
speakers: List[Speaker]
|
||||
|
||||
|
||||
class SpeakerResponse(BaseModel):
|
||||
data: List[CategorySpeakers]
|
||||
|
||||
Reference in New Issue
Block a user