diff --git a/docs/source/tutorials/stable_diffusion.mdx b/docs/source/tutorials/stable_diffusion.mdx
index 648fde3ff..408924cd1 100644
--- a/docs/source/tutorials/stable_diffusion.mdx
+++ b/docs/source/tutorials/stable_diffusion.mdx
@@ -173,6 +173,40 @@ image.save("cat_on_bench.png")
:-------------------------:|:-------------------------:|:-------------------------:|-------------------------:|
| | ***Face of a yellow cat, high resolution, sitting on a park bench*** | |
+### InstructPix2Pix
+
+With the `NeuronStableDiffusionInstructPix2PixPipeline` class, you can apply instruction-based image editing using both text guidance and image guidance.
+
+```python
+import requests
+import PIL
+from io import BytesIO
+from optimum.neuron import NeuronStableDiffusionInstructPix2PixPipeline
+
+def download_image(url):
+ response = requests.get(url)
+ return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+
+model_id = "timbrooks/instruct-pix2pix"
+input_shapes = {"batch_size": 1, "height": 512, "width": 512}
+
+pipe = NeuronStableDiffusionInstructPix2PixPipeline.from_pretrained(
+ model_id, export=True, dynamic_batch_size=True, **input_shapes,
+)
+pipe.save_pretrained("sd_ip2p/")
+
+img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
+init_image = download_image(img_url).resize((512, 512))
+
+prompt = "Add a beautiful sunset"
+image = pipe(prompt=prompt, image=init_image).images[0]
+image.save("sunset_mountain.png")
+```
+
+`image` | `prompt` | output |
+:-------------------------:|:-------------------------:|-------------------------:|
+ | ***Add a beautiful sunset*** | |
+
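+The compiled artifacts saved with `save_pretrained` can be reloaded later to skip recompilation. A minimal sketch, assuming the `sd_ip2p/` directory produced above:
+
+```python
+from optimum.neuron import NeuronStableDiffusionInstructPix2PixPipeline
+
+# Reload the pre-compiled pipeline; `export=True` is no longer needed.
+pipe = NeuronStableDiffusionInstructPix2PixPipeline.from_pretrained("sd_ip2p/")
+```
+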
## Stable Diffusion XL
*There is a notebook version of that tutorial [here](https://github.com/huggingface/optimum-neuron/blob/main/notebooks/stable-diffusion/stable-diffusion-xl-txt2img.ipynb)*.
diff --git a/optimum/neuron/__init__.py b/optimum/neuron/__init__.py
index b48426e43..8cd1328b1 100644
--- a/optimum/neuron/__init__.py
+++ b/optimum/neuron/__init__.py
@@ -48,6 +48,7 @@
"NeuronStableDiffusionPipeline",
"NeuronStableDiffusionImg2ImgPipeline",
"NeuronStableDiffusionInpaintPipeline",
+ "NeuronStableDiffusionInstructPix2PixPipeline",
"NeuronLatentConsistencyModelPipeline",
"NeuronStableDiffusionXLPipeline",
"NeuronStableDiffusionXLImg2ImgPipeline",
@@ -88,6 +89,7 @@
NeuronStableDiffusionControlNetPipeline,
NeuronStableDiffusionImg2ImgPipeline,
NeuronStableDiffusionInpaintPipeline,
+ NeuronStableDiffusionInstructPix2PixPipeline,
NeuronStableDiffusionPipeline,
NeuronStableDiffusionPipelineBase,
NeuronStableDiffusionXLImg2ImgPipeline,
diff --git a/optimum/neuron/modeling_diffusion.py b/optimum/neuron/modeling_diffusion.py
index c13ded908..4e3e193c1 100644
--- a/optimum/neuron/modeling_diffusion.py
+++ b/optimum/neuron/modeling_diffusion.py
@@ -94,6 +94,7 @@
NeuronStableDiffusionControlNetPipelineMixin,
NeuronStableDiffusionImg2ImgPipelineMixin,
NeuronStableDiffusionInpaintPipelineMixin,
+ NeuronStableDiffusionInstructPix2PixPipelineMixin,
NeuronStableDiffusionPipelineMixin,
NeuronStableDiffusionXLControlNetPipelineMixin,
NeuronStableDiffusionXLImg2ImgPipelineMixin,
@@ -1222,6 +1223,12 @@ class NeuronStableDiffusionInpaintPipeline(
__call__ = NeuronStableDiffusionInpaintPipelineMixin.__call__
+class NeuronStableDiffusionInstructPix2PixPipeline(
+ NeuronStableDiffusionPipelineBase, NeuronStableDiffusionInstructPix2PixPipelineMixin
+):
+ __call__ = NeuronStableDiffusionInstructPix2PixPipelineMixin.__call__
+
+
class NeuronLatentConsistencyModelPipeline(NeuronStableDiffusionPipelineBase, NeuronLatentConsistencyPipelineMixin):
__call__ = NeuronLatentConsistencyPipelineMixin.__call__
diff --git a/optimum/neuron/pipelines/__init__.py b/optimum/neuron/pipelines/__init__.py
index aa5366fc8..d4da684ab 100644
--- a/optimum/neuron/pipelines/__init__.py
+++ b/optimum/neuron/pipelines/__init__.py
@@ -24,6 +24,7 @@
"NeuronStableDiffusionPipelineMixin",
"NeuronStableDiffusionImg2ImgPipelineMixin",
"NeuronStableDiffusionInpaintPipelineMixin",
+ "NeuronStableDiffusionInstructPix2PixPipelineMixin",
"NeuronLatentConsistencyPipelineMixin",
"NeuronStableDiffusionControlNetPipelineMixin",
"NeuronStableDiffusionXLPipelineMixin",
@@ -39,6 +40,7 @@
NeuronStableDiffusionControlNetPipelineMixin,
NeuronStableDiffusionImg2ImgPipelineMixin,
NeuronStableDiffusionInpaintPipelineMixin,
+ NeuronStableDiffusionInstructPix2PixPipelineMixin,
NeuronStableDiffusionPipelineMixin,
NeuronStableDiffusionXLControlNetPipelineMixin,
NeuronStableDiffusionXLImg2ImgPipelineMixin,
diff --git a/optimum/neuron/pipelines/diffusers/__init__.py b/optimum/neuron/pipelines/diffusers/__init__.py
index b39843657..eefe130a0 100644
--- a/optimum/neuron/pipelines/diffusers/__init__.py
+++ b/optimum/neuron/pipelines/diffusers/__init__.py
@@ -19,6 +19,7 @@
from .pipeline_stable_diffusion import NeuronStableDiffusionPipelineMixin
from .pipeline_stable_diffusion_img2img import NeuronStableDiffusionImg2ImgPipelineMixin
from .pipeline_stable_diffusion_inpaint import NeuronStableDiffusionInpaintPipelineMixin
+from .pipeline_stable_diffusion_instruct_pix2pix import NeuronStableDiffusionInstructPix2PixPipelineMixin
from .pipeline_stable_diffusion_xl import NeuronStableDiffusionXLPipelineMixin
from .pipeline_stable_diffusion_xl_img2img import NeuronStableDiffusionXLImg2ImgPipelineMixin
from .pipeline_stable_diffusion_xl_inpaint import NeuronStableDiffusionXLInpaintPipelineMixin
diff --git a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_instruct_pix2pix.py b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_instruct_pix2pix.py
new file mode 100644
index 000000000..9eb659aea
--- /dev/null
+++ b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_instruct_pix2pix.py
@@ -0,0 +1,475 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Override some diffusers API for NeuronStableDiffusionInstructPix2PixPipeline"""
+
+import logging
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
+
+import PIL
+import torch
+from diffusers import StableDiffusionInstructPix2PixPipeline
+from diffusers.loaders import TextualInversionLoaderMixin
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.utils.deprecation_utils import deprecate
+
+from .pipeline_utils import StableDiffusionPipelineMixin
+
+
+if TYPE_CHECKING:
+ from diffusers.image_processor import PipelineImageInput
+
+
+logger = logging.getLogger(__name__)
+
+
+class NeuronStableDiffusionInstructPix2PixPipelineMixin(
+ StableDiffusionPipelineMixin, StableDiffusionInstructPix2PixPipeline
+):
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ image: Optional["PipelineImageInput"] = None,
+ num_inference_steps: int = 100,
+ guidance_scale: float = 7.5,
+ image_guidance_scale: float = 1.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: int = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: str = "pil",
+ return_dict: bool = True,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`Optional[Union[str, List[str]]]`, defaults to `None`):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ image (`Optional["PipelineImageInput"]`, defaults to `None`):
+ `Image` or tensor representing an image batch to be repainted according to `prompt`. Can also accept
+ image latents as `image`, but if passing latents directly it is not encoded again.
+ num_inference_steps (`int`, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ image_guidance_scale (`float`, defaults to 1.5):
+ Push the generated image towards the initial `image`. Image guidance scale is enabled by setting
+ `image_guidance_scale > 1`. Higher image guidance scale encourages generated images that are closely
+ linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a
+ value of at least `1`.
+ negative_prompt (`Optional[Union[str, List[str]]]`, defaults to `None`):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ num_images_per_prompt (`int`, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`diffusers.schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`Optional[Union[torch.Generator, List[torch.Generator]]]`, defaults to `None`):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`Optional[torch.FloatTensor]`, defaults to `None`):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ output_type (`str`, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, defaults to `True`):
+ Whether or not to return a [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback_on_step_end (`Optional[Callable[[int, int, Dict], None]]`, defaults to `None`):
+ A function called at the end of each denoising step during inference, with the following
+ arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List[str]`, defaults to `["latents"]`):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ ```py
+ >>> import PIL
+ >>> import requests
+ >>> from io import BytesIO
+
+ >>> from optimum.neuron import NeuronStableDiffusionInstructPix2PixPipeline
+
+
+ >>> def download_image(url):
+ ... response = requests.get(url)
+ ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+
+
+ >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
+
+ >>> init_image = download_image(img_url).resize((512, 512))
+ >>> compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"}
+ >>> input_shapes = {"batch_size": 1, "height": 512, "width": 512}
+ >>> pipe = NeuronStableDiffusionInstructPix2PixPipeline.from_pretrained(
+ ... "timbrooks/instruct-pix2pix", export=True, dynamic_batch_size=True, **compiler_args, **input_shapes,
+ ... )
+ >>> pipe.save_pretrained("sd_ip2p/")
+
+ >>> prompt = "Add a beautiful sunset"
+ >>> image = pipe(prompt=prompt, image=init_image).images[0]
+ ```
+
+ Returns:
+ [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+
+ # 0. Check inputs
+ self.check_inputs(
+ prompt,
+ None,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+ self._guidance_scale = guidance_scale
+ self._image_guidance_scale = image_guidance_scale
+
+ if image is None:
+ raise ValueError("`image` input cannot be undefined.")
+
+ # 1. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+ neuron_batch_size = self.unet.config.neuron["static_batch_size"]
+ self.check_num_images_per_prompt(batch_size, neuron_batch_size, num_images_per_prompt)
+
+ # 2. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 3. Preprocess image
+ height = self.vae_encoder.config.neuron["static_height"]
+ width = self.vae_encoder.config.neuron["static_width"]
+ image = self.image_processor.preprocess(image, height=height, width=width)
+
+ # 4. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare Image latents
+ image_latents = self.prepare_image_latents(
+ image,
+ batch_size,
+ num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ generator,
+ )
+
+ height, width = image_latents.shape[-2:]
+ height = height * self.vae_scale_factor
+ width = width * self.vae_scale_factor
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae_decoder.config.latent_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ generator,
+ latents,
+ )
+
+ # 7. Check that shapes of latents and image match the UNet channels
+ num_channels_image = image_latents.shape[1]
+ if num_channels_latents + num_channels_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_image`: {num_channels_image} "
+ f" = {num_channels_latents+num_channels_image}. Please verify the config of"
+ " `pipeline.unet` or your `image` input."
+ )
+
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 9. Denoising loop
+ self._num_timesteps = len(timesteps)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # Expand the latents if we are doing classifier free guidance.
+ # The latents are expanded 3 times because for pix2pix the guidance
+ # is applied to both the text and the input image.
+ latent_model_input = torch.cat([latents] * 3) if self.do_classifier_free_guidance else latents
+
+ # concat latents, image_latents in the channel dimension
+ scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ scaled_latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
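+ # The three chunks follow the batch ordering built in `_encode_prompt` and
+ # `prepare_image_latents`: text + image conditioning, image-only conditioning
+ # (negative text), and fully unconditional (negative text, zeroed image latents).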
+ noise_pred = (
+ noise_pred_uncond
+ + self.guidance_scale * (noise_pred_text - noise_pred_image)
+ + self.image_guidance_scale * (noise_pred_image - noise_pred_uncond)
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ image_latents = callback_outputs.pop("image_latents", image_latents)
+
+ if not output_type == "latent":
+ image = self.vae_decoder(latents / getattr(self.vae_decoder.config, "scaling_factor", 0.18215))[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+ # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionInstructPix2PixPipeline.prepare_image_latents
+ def prepare_image_latents(
+ self, image, batch_size, num_images_per_prompt, do_classifier_free_guidance, generator=None
+ ):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if image.shape[1] == 4:
+ image_latents = image
+ else:
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if isinstance(generator, list):
+ image_latents = [self.vae_encoder(sample=image[i : i + 1])[0] for i in range(batch_size)]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = self.vae_encoder(sample=image)[0]
+
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand image_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ if do_classifier_free_guidance:
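+ # Duplicate the conditional image latents and append zeroed latents so the batch
+ # order matches the prompt embeddings: [text + image, negative text + image,
+ # negative text + unconditional image].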
+ uncond_image_latents = torch.zeros_like(image_latents)
+ image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0)
+
+ return image_latents
+
+ # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionInstructPix2PixPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int,
+ do_classifier_free_guidance: bool,
+ negative_prompt: Optional[Union[str, List]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(text_input_ids)
+ prompt_embeds = prompt_embeds[0]
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = self.text_encoder(uncond_input.input_ids)
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ # pix2pix has two negative embeddings, and unlike in other pipelines the embeddings are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
+ prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds])
+
+ return prompt_embeds
+
+ @property
+ def do_classifier_free_guidance(self):
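+ # Guidance also requires dynamic batching: the text, image and unconditional branches
+ # triple the effective batch fed to the statically compiled Neuron UNet, which is only
+ # possible when the model was exported with `dynamic_batch_size=True`.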
+ return self.guidance_scale > 1.0 and self.image_guidance_scale >= 1.0 and self.dynamic_batch_size
diff --git a/tests/inference/inference_utils.py b/tests/inference/inference_utils.py
index f3a0bfda1..b37e434b9 100644
--- a/tests/inference/inference_utils.py
+++ b/tests/inference/inference_utils.py
@@ -58,6 +58,7 @@
"swin": "hf-internal-testing/tiny-random-SwinModel",
"vit": "hf-internal-testing/tiny-random-vit",
"stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch",
+ "stable-diffusion-ip2p": "asntr/tiny-stable-diffusion-pix2pix-torch",
"stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl",
"xlm": "hf-internal-testing/tiny-random-XLMModel",
"xlm-roberta": "hf-internal-testing/tiny-xlm-roberta",
diff --git a/tests/inference/test_stable_diffusion_pipeline.py b/tests/inference/test_stable_diffusion_pipeline.py
index 3006be23b..3d73ed38d 100644
--- a/tests/inference/test_stable_diffusion_pipeline.py
+++ b/tests/inference/test_stable_diffusion_pipeline.py
@@ -29,6 +29,7 @@
NeuronStableDiffusionControlNetPipeline,
NeuronStableDiffusionImg2ImgPipeline,
NeuronStableDiffusionInpaintPipeline,
+ NeuronStableDiffusionInstructPix2PixPipeline,
NeuronStableDiffusionPipeline,
NeuronStableDiffusionXLImg2ImgPipeline,
NeuronStableDiffusionXLInpaintPipeline,
@@ -132,6 +133,22 @@ def test_inpaint_export_and_inference(self, model_arch):
image = neuron_pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
self.assertIsInstance(image, PIL.Image.Image)
+ @parameterized.expand(["stable-diffusion-ip2p"], skip_on_empty=True)
+ def test_instruct_pix2pix_export_and_inference(self, model_arch):
+ neuron_pipeline = NeuronStableDiffusionInstructPix2PixPipeline.from_pretrained(
+ MODEL_NAMES[model_arch],
+ export=True,
+ dynamic_batch_size=True,
+ **self.STATIC_INPUTS_SHAPES,
+ **self.COMPILER_ARGS,
+ )
+
+ img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
+ init_image = download_image(img_url).resize((512, 512))
+ prompt = "Add a beautiful sunset"
+ image = neuron_pipeline(prompt=prompt, image=init_image).images[0]
+ self.assertIsInstance(image, PIL.Image.Image)
+
@parameterized.expand(["latent-consistency"], skip_on_empty=True)
def test_lcm_export_and_inference(self, model_arch):
neuron_pipeline = NeuronLatentConsistencyModelPipeline.from_pretrained(