diff --git a/rvc/configs/config.py b/rvc/configs/config.py index 8cd6012..1c0ba38 100644 --- a/rvc/configs/config.py +++ b/rvc/configs/config.py @@ -20,12 +20,11 @@ import logging logger = logging.getLogger(__name__) -version_config_list = [ - "v1/32k.json", - "v1/40k.json", - "v1/48k.json", - "v2/48k.json", - "v2/32k.json", +version_config_list: list = [ + os.path.join(root, file) + for root, dirs, files in os.walk(os.path.dirname(os.path.abspath(__file__))) + for file in files + if file.endswith(".json") ] @@ -62,11 +61,10 @@ class Config: @staticmethod def load_config_json() -> dict: - d = {} - for config_file in version_config_list: - with open(f"configs/{config_file}", "r") as f: - d[config_file] = json.load(f) - return d + return { + config_file: json.load(open(config_file, "r")) + for config_file in version_config_list + } @staticmethod def arg_parse() -> tuple: @@ -120,18 +118,15 @@ class Config: else: return False - def use_fp32_config(self): - for config_file in version_config_list: - self.json_config[config_file]["train"]["fp16_run"] = False - with open(f"configs/{config_file}", "r") as f: - strr = f.read().replace("true", "false") - with open(f"configs/{config_file}", "w") as f: - f.write(strr) - with open("infer/modules/train/preprocess.py", "r") as f: - strr = f.read().replace("3.7", "3.0") - with open("infer/modules/train/preprocess.py", "w") as f: - f.write(strr) - print("overwrite preprocess and configs.json") + def use_fp32_config(self) -> None: + for config_file, data in self.json_config.items(): + try: + data["train"]["fp16_run"] = False + with open(config_file, "w") as json_file: + json.dump(data, json_file, indent=4) + except Exception as e: + logger.info(f"Error updating {config_file}: {str(e)}") + logger.info("overwrite configs.json") def device_config(self) -> tuple: if torch.cuda.is_available():