# rvc/wrapper/api/endpoints/inference.py
"""FastAPI endpoints for RVC voice-conversion and TTS+VC inference."""
import glob
import json
import logging
import os
import tempfile
from base64 import b64encode
from io import BytesIO
from pathlib import Path

import soundfile as sf
from dotenv import load_dotenv
from fastapi import APIRouter, Body, Form, Query, Response, UploadFile, responses
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from scipy.io import wavfile

from rvc.modules.vc.modules import VC

# Must run before the endpoint defs below: their Body(enum=...) defaults read
# os.getenv('weight_root') at function-definition time.
load_dotenv()

logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/inference")
def inference(
    input_audio: Path | UploadFile,
    modelpath: Path
    | UploadFile = Body(
        ...,
        # Enum of model basenames is computed once at import time from
        # weight_root; restart the server to pick up new weights.
        enum=[
            os.path.basename(file)
            for file in glob.glob(f"{os.getenv('weight_root')}/*")
        ],
    ),
    res_type: str = Query("blob", enum=["blob", "json"]),
    sid: int = 0,
    f0_up_key: int = 0,
    f0_method: str = Query(
        "rmvpe", enum=["pm", "harvest", "dio", "rmvpe", "rmvpe_gpu"]
    ),
    f0_file: Path | None = None,
    index_file: Path | None = None,
    index_rate: float = 0.75,
    filter_radius: int = 3,
    resample_sr: int = 0,
    rms_mix_rate: float = 0.25,
    protect: float = 0.33,
):
    """Run RVC voice conversion on ``input_audio`` with the selected model.

    Returns either a streamed WAV attachment (``res_type='blob'``) or a JSON
    payload with base64-encoded audio plus timing info (``res_type='json'``).
    """
    logger.debug("res_type=%s", res_type)
    # NOTE(review): a fresh VC instance (and model load) per request is
    # expensive — consider caching per modelpath.
    vc = VC()
    vc.get_vc(modelpath)
    tgt_sr, audio_opt, times, _ = vc.vc_inference(
        sid,
        input_audio,
        f0_up_key,
        f0_method,
        f0_file,
        index_file,
        index_rate,
        filter_radius,
        resample_sr,
        rms_mix_rate,
        protect,
    )
    wv = BytesIO()
    wavfile.write(wv, tgt_sr, audio_opt)
    # BUG FIX: wavfile.write leaves the cursor at EOF; without rewinding,
    # both the streamed response and wv.read() would yield an empty body.
    wv.seek(0)
    logger.debug("inference times: %s", times)
    if res_type == "blob":
        return responses.StreamingResponse(
            wv,
            media_type="audio/wav",
            headers={"Content-Disposition": "attachment; filename=inference.wav"},
        )
    return JSONResponse(
        {
            # Round-trip coerces times to plain JSON-safe types (e.g. tuples -> lists).
            "time": json.loads(json.dumps(times)),
            "audio": b64encode(wv.read()).decode("utf-8"),
        }
    )
@router.post("/tts-inference")
def tts_inference(
    text: str = Body(..., description="The text to synthesize"),
    language: str = Body(
        "Chinese",
        description="Language code",
        enum=[
            "Chinese",
            "English",
            "Japanese",
            "Korean",
            "German",
            "French",
            "Russian",
            "Portuguese",
            "Spanish",
            "Italian",
        ],
    ),
    speaker: str = Body("Vivian", description="Speaker/voice profile name"),
    instruct: str = Body("", description="Natural language instruction for controlling timbre, emotion, and prosody"),
    modelpath: Path
    | UploadFile = Body(
        ...,
        # Enum of model basenames is computed once at import time from
        # weight_root; restart the server to pick up new weights.
        enum=[
            os.path.basename(file)
            for file in glob.glob(f"{os.getenv('weight_root')}/*")
        ],
    ),
    res_type: str = Query("blob", enum=["blob", "json"]),
    sid: int = 0,
    f0_up_key: int = 0,
    f0_method: str = Query(
        "rmvpe", enum=["pm", "harvest", "dio", "rmvpe", "rmvpe_gpu"]
    ),
    f0_file: Path | None = None,
    index_file: Path | None = None,
    index_rate: float = 0.75,
    filter_radius: int = 3,
    resample_sr: int = 0,
    rms_mix_rate: float = 0.25,
    protect: float = 0.33,
):
    """
    Perform TTS using Qwen3-TTS followed by voice conversion inference.

    First generates speech from text using Qwen3-TTS, then applies voice
    conversion to transform the generated speech to the target voice.
    Returns a streamed WAV attachment (``res_type='blob'``) or a JSON payload
    with base64-encoded audio plus timing info (``res_type='json'``).
    """
    # Lazy imports: keep heavy TTS/torch deps out of module import time.
    from qwen_tts import Qwen3TTSModel
    import torch

    # NOTE(review): the TTS model is loaded from scratch on every request —
    # very expensive; consider a module-level (or lifespan-managed) cache.
    use_cuda = torch.cuda.is_available()
    tts_model = Qwen3TTSModel.from_pretrained(
        "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
        device_map="cuda:0" if use_cuda else "cpu",
        dtype=torch.bfloat16 if use_cuda else torch.float32,
        attn_implementation="flash_attention_2" if use_cuda else None,
    )

    # Generate TTS audio from the input text.
    wavs, sr = tts_model.generate_custom_voice(
        text=text,
        language=language,
        speaker=speaker,
        instruct=instruct,
    )

    # Persist TTS output to a temp WAV so VC can consume it by path.
    # delete=False + explicit close avoids Windows open-handle issues.
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp_path = tmp.name
    sf.write(tmp_path, wavs[0], sr)
    tmp.close()

    try:
        # Run voice conversion on the generated audio.
        vc = VC()
        vc.get_vc(modelpath)
        tgt_sr, audio_opt, times, _ = vc.vc_inference(
            sid,
            tmp_path,
            f0_up_key,
            f0_method,
            f0_file,
            index_file,
            index_rate,
            filter_radius,
            resample_sr,
            rms_mix_rate,
            protect,
        )
        wv = BytesIO()
        wavfile.write(wv, tgt_sr, audio_opt)
        # BUG FIX: wavfile.write leaves the cursor at EOF; without rewinding,
        # both the streamed response and wv.read() would yield an empty body.
        wv.seek(0)
        logger.debug("tts-inference times: %s", times)
        if res_type == "blob":
            return responses.StreamingResponse(
                wv,
                media_type="audio/wav",
                headers={"Content-Disposition": "attachment; filename=tts_inference.wav"},
            )
        return JSONResponse(
            {
                # Round-trip coerces times to plain JSON-safe types.
                "time": json.loads(json.dumps(times)),
                "audio": b64encode(wv.read()).decode("utf-8"),
            }
        )
    finally:
        # Clean up the temporary TTS file regardless of VC success.
        os.unlink(tmp_path)