func -> call
This commit is contained in:
@@ -11,6 +11,7 @@ from base64 import b64encode
|
||||
from rvc.modules.vc.modules import VC
|
||||
import glob
|
||||
import os
|
||||
import torch
|
||||
|
||||
router = APIRouter()
|
||||
from dotenv import load_dotenv
|
||||
@@ -43,7 +44,6 @@ def inference(
|
||||
rms_mix_rate: float = 0.25,
|
||||
protect: float = 0.33,
|
||||
):
|
||||
print(res_type)
|
||||
vc = VC()
|
||||
vc.get_vc(modelpath)
|
||||
tgt_sr, audio_opt, times, _ = vc.vc_inference(
|
||||
@@ -61,6 +61,10 @@ def inference(
|
||||
)
|
||||
wavfile.write(wv := BytesIO(), tgt_sr, audio_opt)
|
||||
print(times)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if res_type == "blob":
|
||||
return responses.StreamingResponse(
|
||||
wv,
|
||||
|
||||
@@ -1,18 +1,32 @@
|
||||
from fastapi import APIRouter, Response, UploadFile, responses
|
||||
from io import BytesIO
|
||||
|
||||
from fastapi import APIRouter, UploadFile, responses, Query
|
||||
from fastapi.responses import JSONResponse
|
||||
from rvc.modules.uvr5.modules import UVR
|
||||
from base64 import b64encode
|
||||
from scipy.io import wavfile
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/inference")
|
||||
def uvr(inputpath, outputpath, modelname, format):
|
||||
uvr_module = UVR()
|
||||
uvr_module.uvr_wrapper(
|
||||
inputpath, outputpath, model_name=modelname, export_format=format
|
||||
)
|
||||
return responses.StreamingResponse(
|
||||
audio,
|
||||
media_type="audio/wav",
|
||||
headers={"Content-Disposition": "attachment; filename=inference.wav"},
|
||||
)
|
||||
def uvr(
|
||||
inputpath,
|
||||
outputpath,
|
||||
modelname,
|
||||
res_type: str = Query("blob", enum=["blob", "json"]),
|
||||
):
|
||||
arries = [i for i in UVR()(inputpath, outputpath, model_name=modelname)]
|
||||
wavfile.write(wv := BytesIO(), tgt_sr, audio_opt)
|
||||
if res_type == "blob":
|
||||
return responses.StreamingResponse(
|
||||
wv,
|
||||
media_type="audio/wav",
|
||||
headers={"Content-Disposition": "attachment; filename=inference.wav"},
|
||||
)
|
||||
else:
|
||||
return JSONResponse(
|
||||
{
|
||||
"audio": b64encode(wv.read()).decode("utf-8"),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ from pathlib import Path
|
||||
import click
|
||||
from dotenv import load_dotenv
|
||||
from scipy.io import wavfile
|
||||
import torch
|
||||
|
||||
logging.getLogger("numba").setLevel(logging.WARNING)
|
||||
|
||||
@@ -129,4 +130,6 @@ def infer(
|
||||
wavfile.write(outputpath, tgt_sr, audio_opt)
|
||||
click.echo(times)
|
||||
click.echo(f"Finish inference. Check {outputpath}")
|
||||
return tgt_sr, audio_opt, times
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -38,8 +38,5 @@ from rvc.modules.uvr5.modules import UVR
|
||||
help="output Format",
|
||||
)
|
||||
def uvr(modelname, inputpath, outputpath, format):
|
||||
uvr_module = UVR()
|
||||
uvr_module.uvr_wrapper(
|
||||
inputpath, outputpath, model_name=modelname, export_format=format
|
||||
)
|
||||
UVR()(inputpath, outputpath, model_name=modelname, export_format=format)
|
||||
click.echo(f"Finish uvr5. Check {outputpath}")
|
||||
|
||||
Reference in New Issue
Block a user