func -> call

This commit is contained in:
Ftps
2024-06-02 02:30:11 +09:00
parent 5e3dce38f4
commit 043fa4d750
6 changed files with 50 additions and 18 deletions

14
.env Normal file
View File

@@ -0,0 +1,14 @@
OPENBLAS_NUM_THREADS = 1
no_proxy = localhost, 127.0.0.1, ::1
# You can change the location of the model, etc. by changing here
weight_root = /Users/ftps/Retrieval-based-Voice-Conversion-WebUI/assets/weights
weight_uvr5_root = /Users/ftps/Retrieval-based-Voice-Conversion-WebUI/assets/uvr5_weights
index_root = /Users/ftps/Retrieval-based-Voice-Conversion-WebUI/logs
rmvpe_root = /Users/ftps/Retrieval-based-Voice-Conversion-WebUI/assets/rmvpe
hubert_path = /Users/ftps/Retrieval-based-Voice-Conversion-WebUI/assets/hubert/hubert_base.pt
hubert_path_ = /Users/ftps/Downloads/Hubert Base.pt
save_uvr_path = /Users/ftps/Retrieval-based-Voice-Conversion-WebUI/opt
TEMP = /Users/ftps/Retrieval-based-Voice-Conversion-WebUI/TEMP
pretrained = /Users/ftps/Retrieval-based-Voice-Conversion-WebUI/assets/pretrained
exp_dir =

View File

@@ -20,7 +20,7 @@ class UVR:
self.need_reformat: bool = True
self.config: Config = Config()
def uvr_wrapper(
def __call__(
self,
audio_path: Path,
agg: int = 10,

View File

@@ -11,6 +11,7 @@ from base64 import b64encode
from rvc.modules.vc.modules import VC
import glob
import os
import torch
router = APIRouter()
from dotenv import load_dotenv
@@ -43,7 +44,6 @@ def inference(
rms_mix_rate: float = 0.25,
protect: float = 0.33,
):
print(res_type)
vc = VC()
vc.get_vc(modelpath)
tgt_sr, audio_opt, times, _ = vc.vc_inference(
@@ -61,6 +61,10 @@ def inference(
)
wavfile.write(wv := BytesIO(), tgt_sr, audio_opt)
print(times)
if torch.cuda.is_available():
torch.cuda.empty_cache()
if res_type == "blob":
return responses.StreamingResponse(
wv,

View File

@@ -1,18 +1,32 @@
from fastapi import APIRouter, Response, UploadFile, responses
from io import BytesIO
from fastapi import APIRouter, UploadFile, responses, Query
from fastapi.responses import JSONResponse
from rvc.modules.uvr5.modules import UVR
from base64 import b64encode
from scipy.io import wavfile
router = APIRouter()
@router.post("/inference")
def uvr(inputpath, outputpath, modelname, format):
uvr_module = UVR()
uvr_module.uvr_wrapper(
inputpath, outputpath, model_name=modelname, export_format=format
)
return responses.StreamingResponse(
audio,
media_type="audio/wav",
headers={"Content-Disposition": "attachment; filename=inference.wav"},
)
def uvr(
inputpath,
outputpath,
modelname,
res_type: str = Query("blob", enum=["blob", "json"]),
):
arries = [i for i in UVR()(inputpath, outputpath, model_name=modelname)]
wavfile.write(wv := BytesIO(), tgt_sr, audio_opt)
if res_type == "blob":
return responses.StreamingResponse(
wv,
media_type="audio/wav",
headers={"Content-Disposition": "attachment; filename=inference.wav"},
)
else:
return JSONResponse(
{
"audio": b64encode(wv.read()).decode("utf-8"),
}
)

View File

@@ -4,6 +4,7 @@ from pathlib import Path
import click
from dotenv import load_dotenv
from scipy.io import wavfile
import torch
logging.getLogger("numba").setLevel(logging.WARNING)
@@ -129,4 +130,6 @@ def infer(
wavfile.write(outputpath, tgt_sr, audio_opt)
click.echo(times)
click.echo(f"Finish inference. Check {outputpath}")
return tgt_sr, audio_opt, times
if torch.cuda.is_available():
torch.cuda.empty_cache()

View File

@@ -38,8 +38,5 @@ from rvc.modules.uvr5.modules import UVR
help="output Format",
)
def uvr(modelname, inputpath, outputpath, format):
uvr_module = UVR()
uvr_module.uvr_wrapper(
inputpath, outputpath, model_name=modelname, export_format=format
)
UVR()(inputpath, outputpath, model_name=modelname, export_format=format)
click.echo(f"Finish uvr5. Check {outputpath}")