rewrite vr-architecture
This commit is contained in:
+19
-34
@@ -10,7 +10,7 @@ from pydub import AudioSegment
|
||||
|
||||
from rvc.configs.config import Config
|
||||
from rvc.modules.uvr5.mdxnet import MDXNetDereverb
|
||||
from rvc.modules.uvr5.vr import AudioPre, AudioPreDeEcho
|
||||
from rvc.modules.uvr5.vr import AudioPreprocess
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -23,37 +23,18 @@ class UVR:
|
||||
def uvr_wrapper(
|
||||
self,
|
||||
audio_path: Path,
|
||||
save_vocal_path: Path | None = None,
|
||||
save_ins_path: Path | None = None,
|
||||
agg: int = 10,
|
||||
export_format: str = "flac",
|
||||
model_name: str | None = None,
|
||||
temp_path: Path | None = None,
|
||||
temp_dir: Path | None = None,
|
||||
):
|
||||
infos = []
|
||||
save_vocal_path = (
|
||||
os.getenv("save_uvr_path") if not save_vocal_path else save_vocal_path
|
||||
)
|
||||
save_ins_path = (
|
||||
os.getenv("save_uvr_path") if not save_ins_path else save_ins_path
|
||||
)
|
||||
|
||||
infos = list()
|
||||
if model_name is None:
|
||||
model_name = os.path.basename(glob(f"{os.getenv('weight_uvr5_root')}/*")[0])
|
||||
is_hp3 = "HP3" in model_name
|
||||
|
||||
if model_name == "onnx_dereverb_By_FoxJoy":
|
||||
pre_fun = MDXNetDereverb(15, self.config.device)
|
||||
else:
|
||||
func = AudioPre if "DeEcho" not in model_name else AudioPreDeEcho
|
||||
pre_fun = func(
|
||||
agg=int(agg),
|
||||
model_path=os.path.join(
|
||||
os.getenv("weight_uvr5_root"), model_name # + ".pth"
|
||||
),
|
||||
device=self.config.device,
|
||||
is_half=self.config.is_half,
|
||||
)
|
||||
pre_fun = AudioPreprocess(
|
||||
os.path.join(os.getenv("weight_uvr5_root"), model_name), # + ".pth"
|
||||
int(agg),
|
||||
)
|
||||
|
||||
process_paths = (
|
||||
[
|
||||
@@ -65,12 +46,14 @@ class UVR:
|
||||
else audio_path
|
||||
)
|
||||
|
||||
results = []
|
||||
|
||||
for process_path in [process_paths]:
|
||||
print(f"path: {process_path}")
|
||||
info = sf.info(process_path)
|
||||
if not (info.channels == 2 and info.samplerate == "44100"):
|
||||
tmp_path = os.path.join(
|
||||
temp_path or os.environ.get("TEMP"), os.path.basename(process_path)
|
||||
temp_dir or os.environ.get("TEMP"), os.path.basename(process_path)
|
||||
)
|
||||
AudioSegment.from_file(process_path).export(
|
||||
tmp_path,
|
||||
@@ -80,14 +63,16 @@ class UVR:
|
||||
parameters=["-ar", "44100"],
|
||||
)
|
||||
|
||||
pre_fun._path_audio_(
|
||||
process_path,
|
||||
save_vocal_path,
|
||||
save_ins_path,
|
||||
export_format,
|
||||
is_hp3=is_hp3,
|
||||
results.append(
|
||||
|
||||
pre_fun.process(
|
||||
tmp_path or process_path,
|
||||
)
|
||||
|
||||
)
|
||||
infos.append(f"{os.path.basename(process_path)}->Success")
|
||||
yield "\n".join(infos)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return results
|
||||
|
||||
Reference in New Issue
Block a user