This commit is contained in:
Ftps
2023-11-19 04:28:56 +09:00
parent d1106fdd90
commit 59ddaacad9
3 changed files with 1 additions and 5 deletions

View File

@@ -135,9 +135,7 @@ def ipex_init():
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception:
try:
from .gradscaler import (
gradscaler_init,
)
from .gradscaler import gradscaler_init
gradscaler_init()
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler

View File

@@ -1,7 +1,6 @@
import intel_extension_for_pytorch as ipex
import torch
original_torch_bmm = torch.bmm

View File

@@ -4,7 +4,6 @@ import intel_extension_for_pytorch as ipex
import intel_extension_for_pytorch._C as core
import torch
OptState = ipex.cpu.autocast._grad_scaler.OptState
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
_refresh_per_optimizer_state = (