diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index d88313f455..e9416c8057 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -34,6 +34,7 @@ from .diffusion import (
     BasicConditioningInfo,
 )
 from ..util import normalize_device, auto_detect_slice_size
+from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
 
 
 @dataclass
@@ -357,6 +358,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         additional_guidance: List[Callable] = None,
         callback: Callable[[PipelineIntermediateState], None] = None,
         control_data: List[ControlNetData] = None,
+        ip_adapter_image: Optional[PIL.Image] = None,
         mask: Optional[torch.Tensor] = None,
         masked_latents: Optional[torch.Tensor] = None,
         seed: Optional[int] = None,
@@ -408,6 +410,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 conditioning_data,
                 additional_guidance=additional_guidance,
                 control_data=control_data,
+                ip_adapter_image=ip_adapter_image,
                 callback=callback,
             )
         finally:
@@ -427,8 +430,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         *,
         additional_guidance: List[Callable] = None,
         control_data: List[ControlNetData] = None,
+        ip_adapter_image: Optional[PIL.Image] = None,
         callback: Callable[[PipelineIntermediateState], None] = None,
     ):
+
         self._adjust_memory_efficient_attention(latents)
         if additional_guidance is None:
             additional_guidance = []
@@ -439,6 +444,55 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if timesteps.shape[0] == 0:
             return latents, attention_map_saver
 
+        print("ip_adapter_image: ", type(ip_adapter_image))
+        if ip_adapter_image is not None:
+            # initialize IPAdapter
+            print(" width:", ip_adapter_image.width, " height:", ip_adapter_image.height)
+            clip_image_encoder_path = "ip_adapter_models_sd_1.5/image_encoder/"
+            ip_adapter_model_path = "ip_adapter_models_sd_1.5/ip-adapter_sd15.bin"
+            # FIXME:
+            # WARNING!
+            # IPAdapter constructor modifies UNet model in-place
+            # Adds additional cross-attention layers to UNet model for image embedding
+            # need to figure out how to only do this if UNet hasn't already been modified by prior IPAdapter
+            # and how to undo if ip_adapter_image is removed
+            # use existing model management context etc?
+            #
+            ip_adapter = IPAdapter(self,                      # IPAdapter first arg is StableDiffusionPipeline
+                                   clip_image_encoder_path,   # hardwiring to manually downloaded dir for first pass
+                                   ip_adapter_model_path,     # hardwiring to manually downloaded loc for first pass
+                                   "cuda")                    # hardwiring CUDA GPU for first pass
+            # IP-Adapter ==> add additional cross-attention layers to UNet model here?
+            print("ip_adapter:", ip_adapter)
+
+            # get image embedding from CLIP and ImageProjModel
+            print("getting image embeddings from IP-Adapter...")
+            num_samples = 1  # hardwiring for first pass
+            image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_image)
+            print("image cond embeds shape:", image_prompt_embeds.shape)
+            print("image uncond embeds shape:", uncond_image_prompt_embeds.shape)
+            bs_embed, seq_len, _ = image_prompt_embeds.shape
+            image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
+            image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+            uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
+            uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+            print("image cond embeds shape:", image_prompt_embeds.shape)
+            print("image uncond embeds shape:", uncond_image_prompt_embeds.shape)
+
+            # IP-Adapter: run IP-Adapter model here?
+            # and add output as additional cross-attention layers
+            text_prompt_embeds = conditioning_data.text_embeddings.embeds
+            uncond_text_prompt_embeds = conditioning_data.unconditioned_embeddings.embeds
+            print("text embeds shape:", text_prompt_embeds.shape)
+            concat_prompt_embeds = torch.cat([text_prompt_embeds, image_prompt_embeds], dim=1)
+            concat_uncond_prompt_embeds = torch.cat([uncond_text_prompt_embeds, uncond_image_prompt_embeds], dim=1)
+            print("concat embeds shape:", concat_prompt_embeds.shape)
+            conditioning_data.text_embeddings.embeds = concat_prompt_embeds
+            conditioning_data.unconditioned_embeddings.embeds = concat_uncond_prompt_embeds
+        else:
+            image_prompt_embeds = None
+            uncond_image_prompt_embeds = None
+
         extra_conditioning_info = conditioning_data.extra
         with self.invokeai_diffuser.custom_attention_context(
             self.invokeai_diffuser.model,