add use_device() funcion
can switch to a specific device
This commit is contained in:
1
rvc/configs/__init__.py
Normal file
1
rvc/configs/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from rvc.configs.config import Config
|
||||
@@ -105,59 +105,7 @@ class Config:
|
||||
def has_xpu() -> bool:
|
||||
return hasattr(torch, "xpu") and torch.xpu.is_available()
|
||||
|
||||
def use_fp32_config(self) -> None:
|
||||
for config_file, data in self.json_config.items():
|
||||
try:
|
||||
data["train"]["fp16_run"] = False
|
||||
with open(config_file, "w") as json_file:
|
||||
json.dump(data, json_file, indent=4)
|
||||
except Exception as e:
|
||||
logger.info(f"Error updating {config_file}: {str(e)}")
|
||||
logger.info("overwrite configs.json")
|
||||
|
||||
def device_config(self) -> tuple:
|
||||
if torch.cuda.is_available():
|
||||
if self.has_xpu():
|
||||
self.device = self.instead = "xpu:0"
|
||||
self.is_half = True
|
||||
i_device = int(self.device.split(":")[-1])
|
||||
self.gpu_name = torch.cuda.get_device_name(i_device)
|
||||
if (
|
||||
("16" in self.gpu_name and "V100" not in self.gpu_name.upper())
|
||||
or "P40" in self.gpu_name.upper()
|
||||
or "P10" in self.gpu_name.upper()
|
||||
or "1060" in self.gpu_name
|
||||
or "1070" in self.gpu_name
|
||||
or "1080" in self.gpu_name
|
||||
):
|
||||
logger.info(f"Found GPU {self.gpu_name}, force to fp32")
|
||||
self.is_half = False
|
||||
self.use_fp32_config()
|
||||
else:
|
||||
logger.info(f"Found GPU {self.gpu_name}")
|
||||
self.gpu_mem = int(
|
||||
torch.cuda.get_device_properties(i_device).total_memory
|
||||
/ 1024
|
||||
/ 1024
|
||||
/ 1024
|
||||
+ 0.4
|
||||
)
|
||||
elif self.has_mps():
|
||||
logger.info("No supported Nvidia GPU found")
|
||||
self.device = self.instead = "mps"
|
||||
self.is_half = False
|
||||
self.use_fp32_config()
|
||||
elif self.dml:
|
||||
import torch_directml
|
||||
|
||||
self.device = torch_directml.device(torch_directml.default_device())
|
||||
self.is_half = False
|
||||
else:
|
||||
logger.info("No supported Nvidia GPU found")
|
||||
self.device = self.instead = "cpu"
|
||||
self.is_half = False
|
||||
self.use_fp32_config()
|
||||
|
||||
def params_config(self) -> tuple:
|
||||
if self.gpu_mem is not None and self.gpu_mem <= 4:
|
||||
x_pad = 1
|
||||
x_query = 5
|
||||
@@ -175,7 +123,75 @@ class Config:
|
||||
x_query = 6
|
||||
x_center = 38
|
||||
x_max = 41
|
||||
return x_pad, x_query, x_center, x_max
|
||||
|
||||
def use_cuda(self) -> None:
|
||||
if self.has_xpu():
|
||||
self.device = self.instead = "xpu:0"
|
||||
self.is_half = True
|
||||
i_device = int(self.device.split(":")[-1])
|
||||
self.gpu_name = torch.cuda.get_device_name(i_device)
|
||||
if (
|
||||
("16" in self.gpu_name and "V100" not in self.gpu_name.upper())
|
||||
or "P40" in self.gpu_name.upper()
|
||||
or "P10" in self.gpu_name.upper()
|
||||
or "1060" in self.gpu_name
|
||||
or "1070" in self.gpu_name
|
||||
or "1080" in self.gpu_name
|
||||
):
|
||||
logger.info(f"Found GPU {self.gpu_name}, force to fp32")
|
||||
self.is_half = False
|
||||
self.use_fp32_config()
|
||||
else:
|
||||
logger.info(f"Found GPU {self.gpu_name}")
|
||||
self.gpu_mem = int(
|
||||
torch.cuda.get_device_properties(i_device).total_memory / 1024 / 1024 / 1024
|
||||
+ 0.4
|
||||
)
|
||||
|
||||
def use_mps(self) -> None:
|
||||
self.device = self.instead = "mps"
|
||||
self.is_half = False
|
||||
self.use_fp32_config()
|
||||
self.params_config()
|
||||
|
||||
def use_dml(self) -> None:
|
||||
import torch_directml
|
||||
|
||||
self.device = torch_directml.device(torch_directml.default_device())
|
||||
self.is_half = False
|
||||
self.params_config()
|
||||
|
||||
def use_cpu(self) -> None:
|
||||
self.device = self.instead = "cpu"
|
||||
self.is_half = False
|
||||
self.use_fp32_config()
|
||||
self.params_config()
|
||||
|
||||
def use_fp32_config(self) -> None:
|
||||
for config_file, data in self.json_config.items():
|
||||
try:
|
||||
data["train"]["fp16_run"] = False
|
||||
with open(config_file, "w") as json_file:
|
||||
json.dump(data, json_file, indent=4)
|
||||
except Exception as e:
|
||||
logger.info(f"Error updating {config_file}: {str(e)}")
|
||||
logger.info("overwrite configs.json")
|
||||
|
||||
def device_config(self) -> tuple:
|
||||
if torch.cuda.is_available():
|
||||
self.use_cuda()
|
||||
elif self.has_mps():
|
||||
logger.info("No supported Nvidia GPU found")
|
||||
self.use_mps()
|
||||
elif self.dml:
|
||||
self.use_dml()
|
||||
else:
|
||||
logger.info("No supported Nvidia GPU found")
|
||||
self.device = self.instead = "cpu"
|
||||
self.is_half = False
|
||||
self.use_fp32_config()
|
||||
|
||||
logger.info(f"Use {self.dml or self.instead} instead")
|
||||
logger.info(f"is_half:{self.is_half}, device:{self.device}")
|
||||
return x_pad, x_query, x_center, x_max
|
||||
return self.params_config()
|
||||
|
||||
Reference in New Issue
Block a user