feat: 支持音色切换

This commit is contained in:
2025-07-01 01:27:29 +08:00
parent faa4ca20b1
commit ec6bd7db88
14 changed files with 308 additions and 57 deletions

View File

@@ -0,0 +1,10 @@
from fastapi import APIRouter
from app.constants.tts import SPEAKER_DATA
from app.schemas import SpeakerResponse
router = APIRouter()


# NOTE: the handler keeps its original name for compatibility, but that name
# was carried over from the model-list endpoint. FastAPI derives the OpenAPI
# summary from the function name, so an explicit summary is supplied here to
# keep the generated docs from labelling this route "Get Model Vendors".
@router.get("/list", response_model=SpeakerResponse, summary="Get speaker list")
async def get_model_vendors():
    """Return the selectable TTS voices, grouped by category.

    Returns:
        SpeakerResponse: the static ``SPEAKER_DATA`` table wrapped in the
        response schema.
    """
    return SpeakerResponse(data=SPEAKER_DATA)

View File

@@ -340,55 +340,51 @@ async def save_audio_file(audio_data: bytes, filename: str) -> str:
# 处理单个TTS任务
async def process_tts_task(websocket, message_id: str, text: str):
async def process_tts_task(websocket, message_id: str, text: str, speaker: str = None):
"""处理单个TTS任务独立协程"""
tts_state = None
try:
print(f"开始处理TTS任务 [{message_id}]: {text}")
# 使用传入的speaker如果没有则使用默认的
selected_speaker = speaker if speaker else SPEAKER
try:
print(f"开始处理TTS任务 [{message_id}]: {text}, 使用说话人: {selected_speaker}")
# 获取TTS状态
tts_state = tts_manager.get_tts_state(websocket, message_id)
if not tts_state:
raise Exception(f"找不到TTS状态: {message_id}")
tts_state.is_processing = True
# 生成音频文件名
tts_state.audio_filename = generate_audio_filename()
# 创建独立的TTS连接
tts_state.volc_ws = await create_tts_connection()
# 创建会话
tts_state.session_id = uuid.uuid4().__str__().replace('-', '')
tts_manager.register_session(tts_state.session_id, message_id)
print(f"创建TTS会话 [{message_id}]: {tts_state.session_id}")
header = Header(message_type=FULL_CLIENT_REQUEST,
message_type_specific_flags=MsgTypeFlagWithEvent,
serial_method=JSON).as_bytes()
optional = Optional(event=EVENT_StartSession, sessionId=tts_state.session_id).as_bytes()
payload = get_payload_bytes(event=EVENT_StartSession, speaker=SPEAKER)
# 使用选择的speaker
payload = get_payload_bytes(event=EVENT_StartSession, speaker=selected_speaker)
await send_event(tts_state.volc_ws, header, optional, payload)
raw_data = await tts_state.volc_ws.recv()
res = parser_response(raw_data)
if res.optional.event != EVENT_SessionStarted:
raise Exception("TTS会话启动失败")
print(f"TTS会话创建成功 [{message_id}]: {tts_state.session_id}")
# 发送文本到TTS服务
print(f"发送文本到TTS服务 [{message_id}]...")
header = Header(message_type=FULL_CLIENT_REQUEST,
message_type_specific_flags=MsgTypeFlagWithEvent,
serial_method=JSON).as_bytes()
optional = Optional(event=EVENT_TaskRequest, sessionId=tts_state.session_id).as_bytes()
payload = get_payload_bytes(event=EVENT_TaskRequest, text=text, speaker=SPEAKER)
# 使用选择的speaker
payload = get_payload_bytes(event=EVENT_TaskRequest, text=text, speaker=selected_speaker)
await send_event(tts_state.volc_ws, header, optional, payload)
# 接收TTS响应并发送到前端
print(f"开始接收TTS响应 [{message_id}]...")
audio_count = 0
try:
while True:
raw_data = await asyncio.wait_for(
@@ -396,17 +392,13 @@ async def process_tts_task(websocket, message_id: str, text: str):
timeout=30
)
res = parser_response(raw_data)
print(f"收到TTS事件 [{message_id}]: {res.optional.event}")
if res.optional.event == EVENT_TTSSentenceEnd:
print(f"句子结束事件 [{message_id}] - 直接完成")
break
elif res.optional.event == EVENT_SessionFinished:
print(f"收到会话结束事件 [{message_id}]")
break
elif res.optional.event == EVENT_TTSResponse:
audio_count += 1
print(f"收到音频数据 [{message_id}] #{audio_count},大小: {len(res.payload)}")
@@ -423,10 +415,8 @@ async def process_tts_task(websocket, message_id: str, text: str):
})
else:
print(f"未知TTS事件 [{message_id}]: {res.optional.event}")
except asyncio.TimeoutError:
print(f"TTS响应超时 [{message_id}],强制结束")
# 异步保存音频文件
if tts_state.audio_data:
file_path = await save_audio_file(
@@ -435,15 +425,15 @@ async def process_tts_task(websocket, message_id: str, text: str):
)
print(f"音频文件已保存 [{message_id}]: {file_path}")
# 发送完成消息,包含文件路径
# 发送完成消息,包含文件路径和使用的speaker
await websocket.send_json({
"type": "tts_audio_complete",
"messageId": message_id,
"audioFile": tts_state.audio_filename,
"audioPath": os.path.join(TEMP_AUDIO_DIR, tts_state.audio_filename) if tts_state.audio_data else None
"audioPath": os.path.join(TEMP_AUDIO_DIR, tts_state.audio_filename) if tts_state.audio_data else None,
"speaker": selected_speaker
})
print(f"TTS处理完成 [{message_id}],共发送 {audio_count} 个音频包")
print(f"TTS处理完成 [{message_id}],共发送 {audio_count} 个音频包,使用说话人: {selected_speaker}")
except asyncio.CancelledError:
print(f"TTS任务被取消 [{message_id}]")
await websocket.send_json({
@@ -474,14 +464,14 @@ async def process_tts_task(websocket, message_id: str, text: str):
# Start TTS conversion for a piece of text
async def handle_tts_text(websocket, message_id: str, text: str, speaker: str = None):
    """Register TTS state for a message and start synthesis in the background.

    Args:
        websocket: Client connection; handed to ``tts_manager.add_tts_state``
            and to the synthesis task.
        message_id: Identifier of the message being synthesized.
        text: Text to convert to speech.
        speaker: Voice ID forwarded unchanged to ``process_tts_task``
            (``None`` when the caller did not pick one).
    """
    # Create a fresh TTS state entry for this message.
    tts_state = tts_manager.add_tts_state(websocket, message_id)
    # Schedule synthesis as its own task (not awaited here) and keep the task
    # handle on the state object. The leftover debug ``print(speaker)`` was
    # removed; process_tts_task already logs the voice it uses.
    tts_state.task = asyncio.create_task(
        process_tts_task(websocket, message_id, text, speaker)
    )

View File

@@ -8,7 +8,7 @@ from . import tts
from app.constants.model_data import tip_message, base_url, headers
async def process_voice_conversation(websocket: WebSocket, asr_text: str, message_id: str):
async def process_voice_conversation(websocket: WebSocket, asr_text: str, message_id: str, speaker: str):
try:
print(f"开始处理语音对话 [{message_id}]: {asr_text}")
@@ -92,7 +92,7 @@ async def process_voice_conversation(websocket: WebSocket, asr_text: str, messag
# 启动TTS处理完整内容
print(f"启动完整TTS处理 [{message_id}]: {full_response}")
await tts.handle_tts_text(websocket, message_id, full_response)
await tts.handle_tts_text(websocket, message_id, full_response, speaker)
except Exception as e:
print(f"语音对话处理异常 [{message_id}]: {e}")

View File

@@ -64,7 +64,8 @@ async def websocket_online_count(websocket: WebSocket):
# 从data中获取messageId如果不存在则生成一个新的ID
message_id = data.get("messageId", "voice_" + str(uuid.uuid4()))
if data.get("voiceConversation"):
await process_voice_conversation(websocket, asr_text, message_id)
speaker = data.get("speaker")
await process_voice_conversation(websocket, asr_text, message_id, speaker)
else:
await websocket.send_json({"type": "asr_result", "result": asr_text})
temp_buffer = bytes()
@@ -73,6 +74,7 @@ async def websocket_online_count(websocket: WebSocket):
elif msg_type == "tts_text":
message_id = data.get("messageId")
text = data.get("text", "")
speaker = data.get("speaker")
if not message_id:
await websocket.send_json({
@@ -83,7 +85,7 @@ async def websocket_online_count(websocket: WebSocket):
print(f"收到TTS文本请求 [{message_id}]: {text}")
try:
await tts.handle_tts_text(websocket, message_id, text)
await tts.handle_tts_text(websocket, message_id, text, speaker)
except Exception as e:
print(f"TTS文本处理异常 [{message_id}]: {e}")
await websocket.send_json({

View File

@@ -4,4 +4,101 @@
# NOTE(review): APP_ID and TOKEN look like live service credentials committed
# to source control — confirm, rotate them if so, and load them from the
# environment or a secret store instead of hard-coding them here.
APP_ID = '2138450044'
TOKEN = 'V04_QumeQZhJrQ_In1Z0VBQm7n0ttMNO'
# Default voice ID.
# NOTE(review): the assignment appears twice with the same value (presumably the
# old and new side of the diff); a single assignment is sufficient.
SPEAKER = 'zh_male_beijingxiaoye_moon_bigtts'
SPEAKER = 'zh_male_beijingxiaoye_moon_bigtts'
# Static catalogue of selectable TTS voices, grouped by category.
# Each category dict has a "category" label and a "speakers" list; each speaker
# entry carries "speaker_id", "speaker_name", "language" and "platforms".
# NOTE(review): a few IDs do not obviously match their display names (e.g. a
# "zh_female_..." ID for 湾区大叔, "guozhou" for 广州德哥) — presumably these
# are the vendor's official IDs; verify against the vendor's voice list.
SPEAKER_DATA = [
{
"category": "趣味口音",
"speakers": [
{
"speaker_id": "zh_male_jingqiangkanye_moon_bigtts",
"speaker_name": "京腔侃爷/Harmony",
"language": "中文-北京口音、英文",
"platforms": ["豆包", "Cici", "web demo"]
},
{
"speaker_id": "zh_female_wanwanxiaohe_moon_bigtts",
"speaker_name": "湾湾小何",
"language": "中文-台湾口音",
"platforms": ["豆包", "Cici"]
},
{
"speaker_id": "zh_female_wanqudashu_moon_bigtts",
"speaker_name": "湾区大叔",
"language": "中文-广东口音",
"platforms": ["豆包", "Cici"]
},
{
"speaker_id": "zh_female_daimengchuanmei_moon_bigtts",
"speaker_name": "呆萌川妹",
"language": "中文-四川口音",
"platforms": ["豆包", "Cici"]
},
{
"speaker_id": "zh_male_guozhoudege_moon_bigtts",
"speaker_name": "广州德哥",
"language": "中文-广东口音",
"platforms": ["豆包", "Cici"]
},
{
"speaker_id": "zh_male_beijingxiaoye_moon_bigtts",
"speaker_name": "北京小爷",
"language": "中文-北京口音",
"platforms": ["豆包"]
},
{
"speaker_id": "zh_male_haoyuxiaoge_moon_bigtts",
"speaker_name": "浩宇小哥",
"language": "中文-青岛口音",
"platforms": ["豆包"]
},
{
"speaker_id": "zh_male_guangxiyuanzhou_moon_bigtts",
"speaker_name": "广西远舟",
"language": "中文-广西口音",
"platforms": ["豆包"]
},
{
"speaker_id": "zh_female_meituojieer_moon_bigtts",
"speaker_name": "妹坨洁儿",
"language": "中文-长沙口音",
"platforms": ["豆包", "剪映"]
},
{
"speaker_id": "zh_male_yuzhouzixuan_moon_bigtts",
"speaker_name": "豫州子轩",
"language": "中文-河南口音",
"platforms": ["豆包"]
}
]
},
{
"category": "角色扮演",
"speakers": [
{
"speaker_id": "zh_male_naiqimengwa_mars_bigtts",
"speaker_name": "奶气萌娃",
"language": "中文",
"platforms": ["剪映", "豆包"]
},
{
"speaker_id": "zh_female_popo_mars_bigtts",
"speaker_name": "婆婆",
"language": "中文",
"platforms": ["剪映C端", "抖音", "豆包"]
},
{
"speaker_id": "zh_female_gaolengyujie_moon_bigtts",
"speaker_name": "高冷御姐",
"language": "中文",
"platforms": ["豆包", "Cici"]
},
{
"speaker_id": "zh_male_aojiaobazong_moon_bigtts",
"speaker_name": "傲娇霸总",
"language": "中文",
"platforms": ["豆包"]
}
]
}
]

View File

@@ -1,5 +1,5 @@
from fastapi import FastAPI
from app.api.v1.endpoints import chat, model, websocket_service
from app.api.v1.endpoints import chat, model, websocket_service,speaker
app = FastAPI()
@@ -9,6 +9,7 @@ app.include_router(websocket_service.router, prefix="", tags=["websocket_service
app.include_router(chat.router, prefix="/v1/chat", tags=["chat"])
# 获取模型列表服务
app.include_router(model.router, prefix="/v1/model", tags=["model_list"])
app.include_router(speaker.router, prefix="/v1/speaker", tags=["speaker_list"])
if __name__ == "__main__":
import uvicorn

View File

@@ -5,4 +5,5 @@ from .chat import (
ModelInfo,
VendorModelList,
VendorModelResponse,
SpeakerResponse
)

View File

@@ -33,3 +33,20 @@ class VendorModelList(BaseModel):
class VendorModelResponse(BaseModel):
data: List[VendorModelList]
# Speaker-related models
class Speaker(BaseModel):
    """A single selectable TTS voice."""

    # Voice identifier (the "speaker_id" key of a SPEAKER_DATA entry).
    speaker_id: str
    # Human-readable display name of the voice.
    speaker_name: str
    # Language / accent description of the voice.
    language: str
    # Names of the platforms the voice is listed for.
    platforms: List[str]
class CategorySpeakers(BaseModel):
    """A named category together with the voices that belong to it."""

    # Category label.
    category: str
    # Voices listed under this category.
    speakers: List[Speaker]
class SpeakerResponse(BaseModel):
    """Response envelope for the speaker list: categories under ``data``."""

    # All voice categories, each with its own speaker list.
    data: List[CategorySpeakers]

View File

@@ -12,7 +12,7 @@ export interface Message {
role?: string;
usage?: UsageInfo;
id?: string;
type?: 'chat' | 'voice';
type?: "chat" | "voice";
[property: string]: any;
}
@@ -32,3 +32,97 @@ export interface UsageInfo {
completion_tokens: number;
total_tokens: number;
}
/**
 * Basic information about a speaker (speech-synthesis voice).
 */
export interface Speaker {
  /** Unique identifier of the speaker */
  speaker_id: string;
  /** Display name of the speaker */
  speaker_name: string;
  /** Supported language / accent */
  language: string;
  /** List of supported platforms */
  platforms: string[];
}
/**
 * A speaker category and the speakers it contains.
 */
export interface CategorySpeakers {
  /** Category name */
  category: string;
  /** Speakers that belong to this category */
  speakers: Speaker[];
}
/**
 * Speaker category enum. The values are the category labels themselves.
 */
export enum SpeakerCategory {
  /** Fun accents */
  ACCENT = "趣味口音",
  /** Role play */
  ROLE_PLAY = "角色扮演"
}
/**
 * Enum of commonly used platforms. The values are the platform names as
 * they appear in a speaker's `platforms` list.
 */
export enum SpeakerPlatform {
  DOUYIN = "抖音",
  DOUBAO = "豆包",
  CICI = "Cici",
  JIANYING = "剪映",
  JIANYING_C = "剪映C端",
  WEB_DEMO = "web demo",
  STORY_AI = "StoryAi",
  MAOXIANG = "猫箱"
}
/**
 * Props of the speaker selector component.
 */
export interface SpeakerSelectorProps {
  /** Currently selected speaker */
  selectedSpeaker?: Speaker;
  /** Callback invoked when a speaker is chosen */
  onSpeakerChange: (speaker: Speaker) => void;
  /** Whether the selector is disabled */
  disabled?: boolean;
  /** Filter by specific categories */
  filterCategories?: SpeakerCategory[];
  /** Filter by specific platforms */
  filterPlatforms?: SpeakerPlatform[];
}
/**
 * Speech-synthesis parameters.
 */
export interface VoiceSynthesisParams {
  /** Speaker to use */
  speaker: Speaker;
  /** Text to synthesize */
  text: string;
  /** Speaking rate (0.5-2.0) */
  speed?: number;
  /** Pitch (0.5-2.0) */
  pitch?: number;
  /** Volume (0.0-1.0) */
  volume?: number;
}
/**
 * Speech-synthesis response.
 */
export interface VoiceSynthesisResponse {
  /** URL of the audio file */
  audio_url: string;
  /** Audio duration in seconds */
  duration: number;
  /** Synthesis status */
  status: "success" | "error";
  /** Error message */
  error_message?: string;
}

View File

@@ -137,4 +137,9 @@ export class ChatService {
public static GetModelList(config?: AxiosRequestConfig<any>) {
return BaseClientService.get(`${this.basePath}/model/list`, config);
}
// Fetch the list of available voices (speakers)
public static GetSpeakerList(config?: AxiosRequestConfig<any>) {
  // Build the endpoint first, then issue the GET with the optional config.
  const endpoint = `${this.basePath}/speaker/list`;
  return BaseClientService.get(endpoint, config);
}
}

View File

@@ -1,5 +1,6 @@
import { useWebSocketStore } from "@/services";
import { convertToPCM16 } from "@/utils";
import { useChatStore } from "./chat_store";
export const useAsrStore = defineStore("asr", () => {
// 是否正在录音
@@ -125,6 +126,7 @@ export const useAsrStore = defineStore("asr", () => {
if (router.currentRoute.value.path === "/voice") {
msg.messageId = messageId;
msg.voiceConversation = true;
msg.speaker = useChatStore().speakerInfo?.speaker_id;
}
sendMessage(JSON.stringify(msg));

View File

@@ -1,7 +1,9 @@
import type {
CategorySpeakers,
IChatWithLLMRequest,
ModelInfo,
ModelListInfo,
Speaker,
UsageInfo
} from "@/interfaces";
import { ChatService } from "@/services";
@@ -20,6 +22,10 @@ export const useChatStore = defineStore("chat", () => {
const thinking = ref<boolean>(false);
// 模型列表
const modelList = ref<ModelListInfo[]>([]);
// 音色列表
const speakerList = ref<CategorySpeakers[]>([]);
// 当前音色信息
const speakerInfo = ref<Speaker | null>(null);
// 在线人数
const onlineCount = ref<number>(0);
@@ -151,6 +157,16 @@ export const useChatStore = defineStore("chat", () => {
}
};
// Fetch the voice (speaker) list and store it in `speakerList`
const getSpeakerList = async () => {
  try {
    const response = await ChatService.GetSpeakerList();
    // Fall back to an empty list so `speakerList` always holds an array,
    // even when the payload lacks the expected `data` field.
    speakerList.value = response.data?.data ?? [];
  } catch (error) {
    // Best effort: a failed request is only logged, the list is left as-is.
    // (The log message previously contained a stray "·" character.)
    console.error("获取音色列表失败:", error);
  }
};
return {
token,
completing,
@@ -162,6 +178,9 @@ export const useChatStore = defineStore("chat", () => {
addMessageToHistory,
clearHistoryMessages,
getModelList,
onlineCount
onlineCount,
speakerList,
getSpeakerList,
speakerInfo
};
});

View File

@@ -1,5 +1,6 @@
import { useAudioWebSocket } from "@/services";
import { createAudioUrl, mergeAudioChunks } from "@/utils";
import { useChatStore } from "./chat_store";
interface AudioState {
isPlaying: boolean;
@@ -12,6 +13,7 @@ interface AudioState {
}
export const useTtsStore = defineStore("tts", () => {
const chatStore = useChatStore();
// 多音频状态管理 - 以消息ID为key
const audioStates = ref<Map<string, AudioState>>(new Map());
@@ -65,7 +67,14 @@ export const useTtsStore = defineStore("tts", () => {
hasActiveSession.value = true;
// 发送文本到TTS服务
sendMessage(JSON.stringify({ type: "tts_text", text, messageId }));
sendMessage(
JSON.stringify({
type: "tts_text",
text,
messageId,
speaker: chatStore.speakerInfo?.speaker_id
})
);
} catch (error) {
handleError(`连接失败: ${error}`, messageId);
}

View File

@@ -9,7 +9,7 @@ import markdown from "@/components/markdown.vue";
import { useAsrStore, useChatStore, useLayoutStore } from "@/stores";
const chatStore = useChatStore();
const { historyMessages, completing, modelList, modelInfo, thinking } =
const { historyMessages, completing, speakerList, speakerInfo, thinking } =
storeToRefs(chatStore);
const asrStore = useAsrStore();
const { isRecording } = storeToRefs(asrStore);
@@ -58,39 +58,43 @@ const handleItemHeaderClick = (name: string) => {
}
};
// 处理选中模型的 ID
const selectedModelId = computed({
get: () => modelInfo.value?.model_id ?? null,
// 处理选中speaker的 ID
const selectedSpeakerId = computed({
get: () => speakerInfo.value?.speaker_id ?? null,
set: (id: string | null) => {
for (const vendor of modelList.value) {
const found = vendor.models.find((model) => model.model_id === id);
for (const category of speakerList.value) {
const found = category.speakers.find(
(speaker) => speaker.speaker_id === id
);
if (found) {
modelInfo.value = found;
speakerInfo.value = found;
return;
}
}
modelInfo.value = null;
speakerInfo.value = null;
}
});
// 监听模型列表变化,更新选项
// 监听speaker列表变化,更新选项
watch(
() => modelList.value,
() => speakerList.value,
(newVal) => {
if (newVal) {
options.value = newVal.map((vendor) => ({
options.value = newVal.map((category) => ({
type: "group",
label: vendor.vendor,
key: vendor.vendor,
children: vendor.models.map((model) => ({
label: model.model_name,
value: model.model_id,
type: model.model_type
label: category.category,
key: category.category,
children: category.speakers.map((speaker) => ({
label: speaker.speaker_name,
value: speaker.speaker_id,
language: speaker.language,
platforms: speaker.platforms
}))
}));
if (newVal.length > 0 && newVal[0].models.length > 0) {
modelInfo.value = newVal[0].models[0];
// 默认选择第一个speaker
if (newVal.length > 0 && newVal[0].speakers.length > 0) {
speakerInfo.value = newVal[0].speakers[0];
}
}
},
@@ -115,7 +119,7 @@ watch(completing, (newVal) => {
});
onMounted(() => {
chatStore.getModelList();
chatStore.getSpeakerList();
});
</script>
@@ -207,7 +211,7 @@ onMounted(() => {
<div class="flex justify-between items-center gap-2">
<div class="flex items-center gap-2">
<NSelect
v-model:value="selectedModelId"
v-model:value="selectedSpeakerId"
label-field="label"
value-field="value"
children-field="children"