From efed797200e156d22027af6b301755bd4e7c67e7 Mon Sep 17 00:00:00 2001 From: Ftps Date: Tue, 26 Mar 2024 23:24:35 +0900 Subject: [PATCH] API: add response with json --- rvc/modules/vc/modules.py | 4 ++-- rvc/modules/vc/utils.py | 2 +- rvc/wrapper/api/endpoints/inference.py | 31 +++++++++++++++++--------- 3 files changed, 23 insertions(+), 14 deletions(-) diff --git a/rvc/modules/vc/modules.py b/rvc/modules/vc/modules.py index 55812fe..54ed383 100644 --- a/rvc/modules/vc/modules.py +++ b/rvc/modules/vc/modules.py @@ -36,8 +36,8 @@ class VC: self.config = Config() - def get_vc(self, sid: str, *to_return_protect: int): - logger.info("Get sid: " + sid) + def get_vc(self, sid: str | Path, *to_return_protect: int): + logger.info("Get sid: " + os.path.basename(sid)) return_protect = [ to_return_protect[0] if self.if_f0 != 0 and to_return_protect else 0.5, diff --git a/rvc/modules/vc/utils.py b/rvc/modules/vc/utils.py index 94f6f53..47c590e 100644 --- a/rvc/modules/vc/utils.py +++ b/rvc/modules/vc/utils.py @@ -13,7 +13,7 @@ def get_index_path_from_model(sid): for name in files if name.endswith(".index") and "trained" not in name ] - if sid.split(".")[0] in f + if str(sid).split(".")[0] in f ), "", ) diff --git a/rvc/wrapper/api/endpoints/inference.py b/rvc/wrapper/api/endpoints/inference.py index 42ae103..1d45225 100644 --- a/rvc/wrapper/api/endpoints/inference.py +++ b/rvc/wrapper/api/endpoints/inference.py @@ -1,10 +1,13 @@ +import json from io import BytesIO from pathlib import Path -from fastapi import APIRouter, Response, UploadFile, responses +from fastapi import APIRouter, Response, UploadFile, Body, responses, Form, Query +from fastapi.responses import JSONResponse + from pydantic import BaseModel from scipy.io import wavfile - +from base64 import b64encode from rvc.modules.vc.modules import VC router = APIRouter() @@ -12,11 +15,12 @@ router = APIRouter() @router.post("/inference") def inference( - modelpath: str | UploadFile, - input: Path | UploadFile, + modelpath: Path | UploadFile, + input_audio: Path | UploadFile, + res_type: str = Query("blob", enum=["blob", "json"]), sid: int = 0, f0_up_key: int = 0, - f0_method: str = "rmvpe", + f0_method: str = Query("rmvpe", enum=["pm", "harvest", "dio", "rmvpe", "rmvpe_gpu"]), f0_file: Path | None = None, index_file: Path | None = None, index_rate: float = 0.75, @@ -25,11 +29,12 @@ def inference( rms_mix_rate: float = 0.25, protect: float = 0.33, ): + print(res_type) vc = VC() vc.get_vc(modelpath) tgt_sr, audio_opt, times, _ = vc.vc_inference( sid, - input, + input_audio, f0_up_key, f0_method, f0_file, @@ -42,8 +47,12 @@ def inference( ) wavfile.write(wv := BytesIO(), tgt_sr, audio_opt) print(times) - return responses.StreamingResponse( - wv, - media_type="audio/wav", - headers={"Content-Disposition": "attachment; filename=inference.wav"}, - ) + if res_type == "blob": + return responses.StreamingResponse( + wv, + media_type="audio/wav", + headers={"Content-Disposition": "attachment; filename=inference.wav"}, + ) + else: + return JSONResponse({"time": json.loads(json.dumps(times)), "audio": b64encode(wv.read()).decode('utf-8')}) +