Optimization of conditional branching
This commit is contained in:
@@ -49,7 +49,6 @@ class Config:
|
|||||||
self.noautoopen,
|
self.noautoopen,
|
||||||
self.dml,
|
self.dml,
|
||||||
) = self.arg_parse()
|
) = self.arg_parse()
|
||||||
self.instead = ""
|
|
||||||
self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()
|
self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -144,92 +143,40 @@ class Config:
|
|||||||
/ 1024
|
/ 1024
|
||||||
+ 0.4
|
+ 0.4
|
||||||
)
|
)
|
||||||
if self.gpu_mem <= 4:
|
|
||||||
with open("infer/modules/train/preprocess.py", "r") as f:
|
|
||||||
strr = f.read().replace("3.7", "3.0")
|
|
||||||
with open("infer/modules/train/preprocess.py", "w") as f:
|
|
||||||
f.write(strr)
|
|
||||||
elif self.has_mps():
|
elif self.has_mps():
|
||||||
logger.info("No supported Nvidia GPU found")
|
logger.info("No supported Nvidia GPU found")
|
||||||
self.device = self.instead = "mps"
|
self.device = self.instead = "mps"
|
||||||
self.is_half = False
|
self.is_half = False
|
||||||
self.use_fp32_config()
|
self.use_fp32_config()
|
||||||
|
elif self.dml:
|
||||||
|
import torch_directml
|
||||||
|
|
||||||
|
self.device = torch_directml.device(torch_directml.default_device())
|
||||||
|
self.is_half = False
|
||||||
else:
|
else:
|
||||||
logger.info("No supported Nvidia GPU found")
|
logger.info("No supported Nvidia GPU found")
|
||||||
self.device = self.instead = "cpu"
|
self.device = self.instead = "cpu"
|
||||||
self.is_half = False
|
self.is_half = False
|
||||||
self.use_fp32_config()
|
self.use_fp32_config()
|
||||||
|
|
||||||
if self.n_cpu == 0:
|
|
||||||
self.n_cpu = cpu_count()
|
|
||||||
|
|
||||||
if self.is_half:
|
|
||||||
# 6G显存配置
|
|
||||||
x_pad = 3
|
|
||||||
x_query = 10
|
|
||||||
x_center = 60
|
|
||||||
x_max = 65
|
|
||||||
else:
|
|
||||||
# 5G显存配置
|
|
||||||
x_pad = 1
|
|
||||||
x_query = 6
|
|
||||||
x_center = 38
|
|
||||||
x_max = 41
|
|
||||||
|
|
||||||
if self.gpu_mem is not None and self.gpu_mem <= 4:
|
if self.gpu_mem is not None and self.gpu_mem <= 4:
|
||||||
x_pad = 1
|
x_pad = 1
|
||||||
x_query = 5
|
x_query = 5
|
||||||
x_center = 30
|
x_center = 30
|
||||||
x_max = 32
|
x_max = 32
|
||||||
if self.dml:
|
elif self.is_half:
|
||||||
logger.info("Use DirectML instead")
|
# 6G PU_RAM conf
|
||||||
if (
|
x_pad = 3
|
||||||
os.path.exists(
|
x_query = 10
|
||||||
"runtime\Lib\site-packages\onnxruntime\capi\DirectML.dll"
|
x_center = 60
|
||||||
)
|
x_max = 65
|
||||||
== False
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
os.rename(
|
|
||||||
"runtime\Lib\site-packages\onnxruntime",
|
|
||||||
"runtime\Lib\site-packages\onnxruntime-cuda",
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
os.rename(
|
|
||||||
"runtime\Lib\site-packages\onnxruntime-dml",
|
|
||||||
"runtime\Lib\site-packages\onnxruntime",
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
# if self.device != "cpu":
|
|
||||||
import torch_directml
|
|
||||||
|
|
||||||
self.device = torch_directml.device(torch_directml.default_device())
|
|
||||||
self.is_half = False
|
|
||||||
else:
|
else:
|
||||||
if self.instead:
|
# 5G GPU_RAM conf
|
||||||
logger.info(f"Use {self.instead} instead")
|
x_pad = 1
|
||||||
if (
|
x_query = 6
|
||||||
os.path.exists(
|
x_center = 38
|
||||||
"runtime\Lib\site-packages\onnxruntime\capi\onnxruntime_providers_cuda.dll"
|
x_max = 41
|
||||||
)
|
|
||||||
== False
|
logger.info(f"Use {self.dml or self.instead} instead")
|
||||||
):
|
logger.info(f"is_half:{self.is_half}, device:{self.device}")
|
||||||
try:
|
|
||||||
os.rename(
|
|
||||||
"runtime\Lib\site-packages\onnxruntime",
|
|
||||||
"runtime\Lib\site-packages\onnxruntime-dml",
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
os.rename(
|
|
||||||
"runtime\Lib\site-packages\onnxruntime-cuda",
|
|
||||||
"runtie\Lib\site-packages\onnxruntime",
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
print("is_half:%s, device:%s" % (self.is_half, self.device))
|
|
||||||
return x_pad, x_query, x_center, x_max
|
return x_pad, x_query, x_center, x_max
|
||||||
|
|||||||
Reference in New Issue
Block a user