Skip to content

Commit

Permalink
Merge pull request #161 from intel/gpu_precision_fix
Browse files Browse the repository at this point in the history
Fix SD1.5 FP16 for some ARL and MTL GPU SKUs.
  • Loading branch information
gblong1 authored Jan 11, 2025
2 parents 6c5284c + 23a1ce1 commit 0846f9e
Showing 1 changed file with 5 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -523,13 +523,15 @@ def __init__(

self.set_dimensions()



def load_model(self, model, model_name, device):
if "NPU" in device:
with open(os.path.join(model, f"{model_name}.blob"), "rb") as f:
return self.core.import_model(f.read(), device)
return self.core.compile_model(os.path.join(model, f"{model_name}.xml"), device)
if "GPU" in device:
return self.core.compile_model(os.path.join(model, f"{model_name}.xml"), device, {'INFERENCE_PRECISION_HINT': 'f32'})
else:
return self.core.compile_model(os.path.join(model, f"{model_name}.xml"), device)


def set_dimensions(self):
latent_shape = self.unet.input(self.unet_input_tensor_name).shape
Expand Down

0 comments on commit 0846f9e

Please sign in to comment.