format
This commit is contained in:
@@ -135,9 +135,7 @@ def ipex_init():
|
||||
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
|
||||
except Exception:
|
||||
try:
|
||||
from .gradscaler import (
|
||||
gradscaler_init,
|
||||
)
|
||||
from .gradscaler import gradscaler_init
|
||||
|
||||
gradscaler_init()
|
||||
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import intel_extension_for_pytorch as ipex
|
||||
import torch
|
||||
|
||||
|
||||
original_torch_bmm = torch.bmm
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import intel_extension_for_pytorch as ipex
|
||||
import intel_extension_for_pytorch._C as core
|
||||
import torch
|
||||
|
||||
|
||||
OptState = ipex.cpu.autocast._grad_scaler.OptState
|
||||
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
|
||||
_refresh_per_optimizer_state = (
|
||||
|
||||
Reference in New Issue
Block a user