add use_device() function
can switch to a specific device
This commit is contained in:
1
rvc/configs/__init__.py
Normal file
1
rvc/configs/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from rvc.configs.config import Config
|
||||||
@@ -105,59 +105,7 @@ class Config:
|
|||||||
def has_xpu() -> bool:
|
def has_xpu() -> bool:
|
||||||
return hasattr(torch, "xpu") and torch.xpu.is_available()
|
return hasattr(torch, "xpu") and torch.xpu.is_available()
|
||||||
|
|
||||||
def use_fp32_config(self) -> None:
|
def params_config(self) -> tuple:
|
||||||
for config_file, data in self.json_config.items():
|
|
||||||
try:
|
|
||||||
data["train"]["fp16_run"] = False
|
|
||||||
with open(config_file, "w") as json_file:
|
|
||||||
json.dump(data, json_file, indent=4)
|
|
||||||
except Exception as e:
|
|
||||||
logger.info(f"Error updating {config_file}: {str(e)}")
|
|
||||||
logger.info("overwrite configs.json")
|
|
||||||
|
|
||||||
def device_config(self) -> tuple:
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
if self.has_xpu():
|
|
||||||
self.device = self.instead = "xpu:0"
|
|
||||||
self.is_half = True
|
|
||||||
i_device = int(self.device.split(":")[-1])
|
|
||||||
self.gpu_name = torch.cuda.get_device_name(i_device)
|
|
||||||
if (
|
|
||||||
("16" in self.gpu_name and "V100" not in self.gpu_name.upper())
|
|
||||||
or "P40" in self.gpu_name.upper()
|
|
||||||
or "P10" in self.gpu_name.upper()
|
|
||||||
or "1060" in self.gpu_name
|
|
||||||
or "1070" in self.gpu_name
|
|
||||||
or "1080" in self.gpu_name
|
|
||||||
):
|
|
||||||
logger.info(f"Found GPU {self.gpu_name}, force to fp32")
|
|
||||||
self.is_half = False
|
|
||||||
self.use_fp32_config()
|
|
||||||
else:
|
|
||||||
logger.info(f"Found GPU {self.gpu_name}")
|
|
||||||
self.gpu_mem = int(
|
|
||||||
torch.cuda.get_device_properties(i_device).total_memory
|
|
||||||
/ 1024
|
|
||||||
/ 1024
|
|
||||||
/ 1024
|
|
||||||
+ 0.4
|
|
||||||
)
|
|
||||||
elif self.has_mps():
|
|
||||||
logger.info("No supported Nvidia GPU found")
|
|
||||||
self.device = self.instead = "mps"
|
|
||||||
self.is_half = False
|
|
||||||
self.use_fp32_config()
|
|
||||||
elif self.dml:
|
|
||||||
import torch_directml
|
|
||||||
|
|
||||||
self.device = torch_directml.device(torch_directml.default_device())
|
|
||||||
self.is_half = False
|
|
||||||
else:
|
|
||||||
logger.info("No supported Nvidia GPU found")
|
|
||||||
self.device = self.instead = "cpu"
|
|
||||||
self.is_half = False
|
|
||||||
self.use_fp32_config()
|
|
||||||
|
|
||||||
if self.gpu_mem is not None and self.gpu_mem <= 4:
|
if self.gpu_mem is not None and self.gpu_mem <= 4:
|
||||||
x_pad = 1
|
x_pad = 1
|
||||||
x_query = 5
|
x_query = 5
|
||||||
@@ -175,7 +123,75 @@ class Config:
|
|||||||
x_query = 6
|
x_query = 6
|
||||||
x_center = 38
|
x_center = 38
|
||||||
x_max = 41
|
x_max = 41
|
||||||
|
return x_pad, x_query, x_center, x_max
|
||||||
|
|
||||||
|
def use_cuda(self) -> None:
    """Configure this Config for a CUDA (or Intel XPU) device.

    Prefers an XPU when one is available, resolves the device index,
    forces fp32 on GPUs with poor fp16 support (16-series without V100,
    P40/P10, GTX 1060/1070/1080), and records total VRAM in ``gpu_mem``.
    """
    # If an Intel XPU is present, use it instead of plain CUDA.
    if self.has_xpu():
        self.device = self.instead = "xpu:0"
        self.is_half = True
    device_index = int(self.device.split(":")[-1])
    self.gpu_name = torch.cuda.get_device_name(device_index)
    upper_name = self.gpu_name.upper()
    # These architectures handle fp16 badly — drop to full precision.
    force_fp32 = (
        ("16" in self.gpu_name and "V100" not in upper_name)
        or "P40" in upper_name
        or "P10" in upper_name
        or "1060" in self.gpu_name
        or "1070" in self.gpu_name
        or "1080" in self.gpu_name
    )
    if force_fp32:
        logger.info(f"Found GPU {self.gpu_name}, force to fp32")
        self.is_half = False
        self.use_fp32_config()
    else:
        logger.info(f"Found GPU {self.gpu_name}")
    # Total device memory in GiB, nudged up by 0.4 before truncation.
    # NOTE(review): reconstructed as running for both branches above —
    # matches the pre-refactor device_config; confirm against upstream.
    self.gpu_mem = int(
        torch.cuda.get_device_properties(device_index).total_memory
        / 1024
        / 1024
        / 1024
        + 0.4
    )
||||||
|
def use_mps(self) -> None:
    """Switch to Apple's MPS backend (fp32 only) and refresh params."""
    self.device = self.instead = "mps"
    # MPS has no reliable half-precision path here.
    self.is_half = False
    self.use_fp32_config()
    self.params_config()
||||||
|
def use_dml(self) -> None:
    """Switch to a DirectML device (fp32 only) and refresh params."""
    # Imported lazily: torch_directml is an optional, Windows-only package.
    import torch_directml

    self.device = torch_directml.device(torch_directml.default_device())
    self.is_half = False
    self.params_config()
||||||
|
def use_cpu(self) -> None:
    """Fall back to CPU execution in full precision and refresh params."""
    self.device = self.instead = "cpu"
    self.is_half = False
    self.use_fp32_config()
    self.params_config()
||||||
|
def use_fp32_config(self) -> None:
    """Rewrite every cached training config with ``fp16_run`` disabled.

    Iterates ``self.json_config`` (path -> parsed JSON) and persists each
    entry back to disk; failures are logged and skipped, best-effort.
    """
    for path, cfg in self.json_config.items():
        try:
            cfg["train"]["fp16_run"] = False
            with open(path, "w") as handle:
                json.dump(cfg, handle, indent=4)
        except Exception as e:
            # Best-effort: a single unwritable file must not abort the rest.
            logger.info(f"Error updating {path}: {str(e)}")
    # NOTE(review): placed after the loop per the most plausible reading of
    # the mangled source — confirm it isn't meant to log per file.
    logger.info("overwrite configs.json")
||||||
|
def device_config(self) -> tuple:
    """Select the best available backend and return inference parameters.

    Priority: CUDA/XPU > MPS > DirectML > CPU. Each branch delegates to
    its ``use_*`` helper, which sets ``device``/``instead``/``is_half``.

    Returns:
        tuple: ``(x_pad, x_query, x_center, x_max)`` from ``params_config()``.
    """
    if torch.cuda.is_available():
        self.use_cuda()
    elif self.has_mps():
        logger.info("No supported Nvidia GPU found")
        self.use_mps()
    elif self.dml:
        self.use_dml()
    else:
        logger.info("No supported Nvidia GPU found")
        # Fix: the original inlined use_cpu()'s body here; call the helper
        # instead so the CPU path stays consistent with the MPS/DML branches.
        self.use_cpu()

    logger.info(f"Use {self.dml or self.instead} instead")
    logger.info(f"is_half:{self.is_half}, device:{self.device}")
    return self.params_config()
|
||||||
|
|||||||
Reference in New Issue
Block a user