Add support for mps device
mnmly committed Apr 28, 2023
1 parent b0ed7e9 commit 56245ab
Showing 2 changed files with 30 additions and 13 deletions.
29 changes: 18 additions & 11 deletions scripts/python/sdpipeline/image_solve.py
@@ -70,6 +70,10 @@
 
 
 def run(inference_steps, latent_dimension, input_embeddings, controlnet_geo, attention_slicing, guidance_scale, input_scheduler, torch_device, model="CompVis/stable-diffusion-v1-4", local_cache_only=True):
+
+    no_half = 'mps' == torch_device
+    dtype_unet = torch.float32 if no_half else torch.float16
+    dtype_controlnet = numpy.float32 if no_half else numpy.float16
     scheduler_config = input_scheduler["config"]
     t_start = scheduler_config["init_timesteps"]
 
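In short, the commit stops hard-coding half precision and instead derives the working dtypes from the target device once, up front: float32 everywhere when running on Apple's MPS backend, float16 otherwise. A minimal sketch of the pattern (the helper name pick_dtypes is hypothetical, not from this repo):

    import numpy
    import torch

    def pick_dtypes(torch_device):
        # MPS keeps the whole pipeline in float32; other devices use half precision.
        no_half = torch_device == 'mps'
        dtype_unet = torch.float32 if no_half else torch.float16
        dtype_controlnet = numpy.float32 if no_half else numpy.float16
        return dtype_unet, dtype_controlnet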
@@ -80,7 +84,7 @@ def run(inference_steps, latent_dimension, input_embeddings, controlnet_geo, att
     scheduler.set_timesteps(inference_steps)
     timesteps = scheduler.timesteps
 
-    unet = UNet2DConditionModel.from_pretrained(model, subfolder="unet", local_files_only=local_cache_only, torch_dtype=torch.float16)
+    unet = UNet2DConditionModel.from_pretrained(model, subfolder="unet", local_files_only=local_cache_only, torch_dtype=dtype_unet)
     unet.to(torch_device)
 
     if attention_slicing:
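With the dtype parameterized, the UNet load stays a single from_pretrained call. A standalone sketch, assuming the checkpoint is available; the "auto" slicing value is illustrative, since the repo's handling of the attention_slicing flag is outside this hunk:

    import torch
    from diffusers import UNet2DConditionModel

    device = 'mps' if torch.backends.mps.is_available() else 'cpu'
    dtype = torch.float32 if device == 'mps' else torch.float16
    unet = UNet2DConditionModel.from_pretrained(
        "CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=dtype)
    unet.to(device)
    unet.set_attention_slice("auto")  # trades a little speed for lower peak memory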
@@ -96,7 +100,9 @@ def run(inference_steps, latent_dimension, input_embeddings, controlnet_geo, att
 
     text_embeddings = numpy.array(input_embeddings["conditional_embedding"]).reshape(input_embeddings["tensor_shape"])
     uncond_embeddings = numpy.array(input_embeddings["unconditional_embedding"]).reshape(input_embeddings["tensor_shape"])
-    text_embeddings = torch.from_numpy(numpy.array([uncond_embeddings, text_embeddings])).to(torch_device).half()
+    text_embeddings = torch.from_numpy(numpy.array([uncond_embeddings, text_embeddings])).to(torch_device)
+    if not no_half:
+        text_embeddings = text_embeddings.half()
 
     if controlnet_geo:
         controlnet_model = []
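A note on the new branch: since dtype_unet already encodes the device check, the same effect could be had with a single cast. A small equivalent sketch (the 2 x 77 x 768 embedding shape of SD 1.x is assumed):

    import numpy
    import torch

    torch_device = 'mps' if torch.backends.mps.is_available() else 'cpu'
    dtype_unet = torch.float32 if torch_device == 'mps' else torch.float16
    embeddings = torch.from_numpy(numpy.zeros((2, 77, 768), dtype=numpy.float32))
    # Equivalent to the `if not no_half: ... .half()` branch above:
    embeddings = embeddings.to(torch_device).to(dtype_unet)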
@@ -108,14 +114,14 @@ def run(inference_steps, latent_dimension, input_embeddings, controlnet_geo, att
             width = int(geo.attribValue("image_dimension")[0])
             height = int(geo.attribValue("image_dimension")[1])
             controlnet_conditioning_scale = point.attribValue("scale")
-            r = numpy.array(geo.pointFloatAttribValues("r"), dtype=numpy.float16).reshape(width, height)
-            g = numpy.array(geo.pointFloatAttribValues("g"), dtype=numpy.float16).reshape(width, height)
-            b = numpy.array(geo.pointFloatAttribValues("b"), dtype=numpy.float16).reshape(width, height)
+            r = numpy.array(geo.pointFloatAttribValues("r"), dtype=dtype_controlnet).reshape(width, height)
+            g = numpy.array(geo.pointFloatAttribValues("g"), dtype=dtype_controlnet).reshape(width, height)
+            b = numpy.array(geo.pointFloatAttribValues("b"), dtype=dtype_controlnet).reshape(width, height)
             input_colors = numpy.flip(numpy.stack((r, g, b), axis=0),2)
 
             controlnet_conditioning_image = torch.from_numpy(numpy.array([input_colors])).to(device=torch_device)
-            controlnet_conditioning_image = controlnet_conditioning_image.to(torch.float16)
-            controlnet = ControlNetModel.from_pretrained(controlnetmodel, local_files_only=local_cache_only, torch_dtype=torch.float16)
+            controlnet_conditioning_image = controlnet_conditioning_image.to(dtype_unet)
+            controlnet = ControlNetModel.from_pretrained(controlnetmodel, local_files_only=local_cache_only, torch_dtype=dtype_unet)
             controlnet.to(torch_device)
             controlnet_model.append(controlnet)
             controlnet_image.append(controlnet_conditioning_image)
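Only the dtypes change in this hunk; the conditioning image is still assembled from the geometry's r/g/b point attributes. A self-contained sketch of that assembly, with zero arrays standing in for the Houdini attribute data:

    import numpy
    import torch

    torch_device = 'mps' if torch.backends.mps.is_available() else 'cpu'
    dtype_unet = torch.float32 if torch_device == 'mps' else torch.float16
    dtype_controlnet = numpy.float32 if torch_device == 'mps' else numpy.float16
    width = height = 512
    r = numpy.zeros(width * height, dtype=dtype_controlnet).reshape(width, height)
    g = numpy.zeros(width * height, dtype=dtype_controlnet).reshape(width, height)
    b = numpy.zeros(width * height, dtype=dtype_controlnet).reshape(width, height)
    input_colors = numpy.flip(numpy.stack((r, g, b), axis=0), 2)
    # numpy.array(...) copies the flipped view, so from_numpy gets contiguous memory.
    cond = torch.from_numpy(numpy.array([input_colors])).to(device=torch_device, dtype=dtype_unet)
    # cond is 1 x 3 x H x W, the layout ControlNetModel expects for controlnet_cond.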
@@ -127,18 +133,19 @@ def run(inference_steps, latent_dimension, input_embeddings, controlnet_geo, att
     latents = init_latents
 
     with hou.InterruptableOperation("Solving Stable Diffusion", open_interrupt_dialog=True) as operation:
         # if True:
         for i, t in enumerate(tqdm(timesteps, disable=True)):
             latent_model_input = torch.cat([latents] * 2)
-            latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t).to(torch.float16)
+            latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t).to(dtype_unet)
 
             if controlnet_geo:
-                down_block_res_samples, mid_block_res_sample = controlnet(latent_model_input, t.to(torch.float16), encoder_hidden_states=text_embeddings, controlnet_cond=controlnet_image, conditioning_scale=controlnet_scale, return_dict=False,)
+                down_block_res_samples, mid_block_res_sample = controlnet(latent_model_input, t.to(dtype_unet), encoder_hidden_states=text_embeddings, controlnet_cond=controlnet_image, conditioning_scale=controlnet_scale, return_dict=False,)
 
                 with torch.no_grad():
-                    noise_pred = unet(latent_model_input, t.to(torch.float16), encoder_hidden_states=text_embeddings, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, ).sample
+                    noise_pred = unet(latent_model_input, t.to(dtype_unet), encoder_hidden_states=text_embeddings, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, ).sample
             else:
                 with torch.no_grad():
-                    noise_pred = unet(latent_model_input, t.to(torch.float16), encoder_hidden_states=text_embeddings).sample
+                    noise_pred = unet(latent_model_input, t.to(dtype_unet), encoder_hidden_states=text_embeddings).sample
 
             # perform guidance
             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
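The hunk cuts off just before the guidance math: the stacked unconditional and conditional predictions are recombined using guidance_scale. The standard classifier-free guidance update, shown with stand-in tensors rather than the file's exact continuation:

    import torch

    guidance_scale = 7.5
    noise_pred = torch.randn(2, 4, 64, 64)  # stacked [unconditional, conditional] UNet output
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)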
14 changes: 12 additions & 2 deletions scripts/python/sdpipeline/scheduler.py
@@ -13,7 +13,13 @@ def run(input_latents, latent_dimension, image_latents, guiding_strength, infere
         print(f"Unexpected {err}, {type(err)}")
 
     scheduler_object.set_timesteps(inference_steps)
-    noise_latents = torch.from_numpy(numpy.array([input_latents.reshape(4, latent_dimension[0], latent_dimension[1])])).to(torch_device)
+    noise_latents = torch.from_numpy(numpy.array([input_latents.reshape(4, latent_dimension[0], latent_dimension[1])]))
+
+    # If torch_device is `mps`, make sure dtype is set to float32,
+    # as MPS currently cannot handle float64.
+    if torch_device == 'mps':
+        noise_latents = noise_latents.to(torch.float32)
+    noise_latents = noise_latents.to(torch_device)
 
     scheduler = {}
     scheduler["guided_latents"] = noise_latents.cpu().numpy()[0]
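The staged transfer matters because NumPy arrays default to float64, torch.from_numpy preserves that dtype, and the MPS backend has no float64 support, so the downcast has to happen before the tensor is moved to the device. A runnable sketch of the pattern:

    import numpy
    import torch

    torch_device = 'mps' if torch.backends.mps.is_available() else 'cpu'
    latents = numpy.random.randn(1, 4, 64, 64)   # numpy produces float64 by default
    t = torch.from_numpy(latents)                # dtype is torch.float64 here
    if torch_device == 'mps':
        t = t.to(torch.float32)                  # cast first; float64 would fail on mps
    t = t.to(torch_device)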
@@ -28,7 +34,11 @@ def run(input_latents, latent_dimension, image_latents, guiding_strength, infere
 
     if len(image_latents) != 0:
         if guiding_strength > 0.05 and guiding_strength < 1.0:
-            image_latents = torch.from_numpy(numpy.array([image_latents.reshape(4, latent_dimension[0], latent_dimension[1])])).to(torch_device)
+            image_latents = torch.from_numpy(numpy.array([image_latents.reshape(4, latent_dimension[0], latent_dimension[1])]))
+            # for mps device, make sure dtype is set to float32
+            if torch_device == 'mps':
+                image_latents = image_latents.to(torch.float32)
+            image_latents = image_latents.to(torch_device)
             guided_latents = scheduler_object.add_noise(image_latents, noise_latents, timesteps)
             scheduler["guided_latents"] = guided_latents.cpu().numpy()[0]
             t_start = max(inference_steps - init_timestep, 0)
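For context, add_noise is what turns the image latents into a partially noised starting point, so denoising can resume at t_start instead of starting from pure noise. A sketch against the diffusers scheduler API; the scheduler class, sizes, and init_timestep are stand-ins:

    import torch
    from diffusers import PNDMScheduler

    scheduler_object = PNDMScheduler()  # default config, for illustration only
    inference_steps = 30
    scheduler_object.set_timesteps(inference_steps)
    image_latents = torch.randn(1, 4, 64, 64)
    noise_latents = torch.randn_like(image_latents)
    timesteps = scheduler_object.timesteps[:1]  # noise up to the earliest kept step
    guided_latents = scheduler_object.add_noise(image_latents, noise_latents, timesteps)
    init_timestep = 10  # stand-in; derived from guiding_strength upstream
    t_start = max(inference_steps - init_timestep, 0)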
