diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py
index 81fd1f9f5d..8902152538 100644
--- a/invokeai/app/invocations/ip_adapter.py
+++ b/invokeai/app/invocations/ip_adapter.py
@@ -36,7 +36,7 @@ class CLIPVisionModelField(BaseModel):
 
 
 class IPAdapterField(BaseModel):
-    image: ImageField = Field(description="The IP-Adapter image prompt.")
+    image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
     ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
     image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.")
     weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
@@ -55,12 +55,12 @@ class IPAdapterOutput(BaseInvocationOutput):
     ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")
 
 
-@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.0.0")
+@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.0")
 class IPAdapterInvocation(BaseInvocation):
     """Collects IP-Adapter info to pass to other nodes."""
 
     # Inputs
-    image: ImageField = InputField(description="The IP-Adapter image prompt.")
+    image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
     ip_adapter_model: IPAdapterModelField = InputField(
         description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
     )
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 7ce0ae7a8a..c28c87395d 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -214,7 +214,7 @@ def get_scheduler(
     title="Denoise Latents",
     tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
     category="latents",
-    version="1.3.0",
+    version="1.4.0",
 )
 class DenoiseLatentsInvocation(BaseInvocation):
     """Denoises noisy latents to decodable images"""
@@ -491,16 +491,21 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 context=context,
             )
 
-            input_image = context.services.images.get_pil_image(single_ip_adapter.image.image_name)
+            # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
+            single_ipa_images = single_ip_adapter.image
+            if not isinstance(single_ipa_images, list):
+                single_ipa_images = [single_ipa_images]
+
+            single_ipa_images = [context.services.images.get_pil_image(image.image_name) for image in single_ipa_images]
 
             # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
             # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
             with image_encoder_model_info as image_encoder_model:
                 # Get image embeddings from CLIP and ImageProjModel.
-                (
-                    image_prompt_embeds,
-                    uncond_image_prompt_embeds,
-                ) = ip_adapter_model.get_image_embeds(input_image, image_encoder_model)
+                image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
+                    single_ipa_images, image_encoder_model
+                )
+
                 conditioning_data.ip_adapter_conditioning.append(
                     IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)
                 )
diff --git a/invokeai/backend/ip_adapter/attention_processor.py b/invokeai/backend/ip_adapter/attention_processor.py
index 2873c52322..c06d7d113c 100644
--- a/invokeai/backend/ip_adapter/attention_processor.py
+++ b/invokeai/backend/ip_adapter/attention_processor.py
@@ -67,6 +67,12 @@ class IPAttnProcessor2_0(torch.nn.Module):
         temb=None,
         ip_adapter_image_prompt_embeds=None,
     ):
+        """Apply IP-Adapter attention.
+
+        Args:
+            ip_adapter_image_prompt_embeds (torch.Tensor): The image prompt embeddings.
+                Shape: (batch_size, num_ip_images, seq_len, ip_embedding_len).
+        """
         residual = hidden_states
 
         if attn.spatial_norm is not None:
@@ -127,26 +133,35 @@ class IPAttnProcessor2_0(torch.nn.Module):
             for ipa_embed, ipa_weights, scale in zip(ip_adapter_image_prompt_embeds, self._weights, self._scales):
                 # The batch dimensions should match.
                 assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
-                # The channel dimensions should match.
-                assert ipa_embed.shape[2] == encoder_hidden_states.shape[2]
+                # The token_len dimensions should match.
+                assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
 
                 ip_hidden_states = ipa_embed
 
+                # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
+
                 ip_key = ipa_weights.to_k_ip(ip_hidden_states)
                 ip_value = ipa_weights.to_v_ip(ip_hidden_states)
 
+                # Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
+
                 ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                 ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
-                # The output of sdpa has shape: (batch, num_heads, seq_len, head_dim)
+                # Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
+
                 # TODO: add support for attn.scale when we move to Torch 2.1
                 ip_hidden_states = F.scaled_dot_product_attention(
                     query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
                 )
 
+                # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
+
                 ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
                 ip_hidden_states = ip_hidden_states.to(query.dtype)
 
+                # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
+
                 hidden_states = hidden_states + scale * ip_hidden_states
 
             # linear proj
diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
index 7c3f835a44..6a63c225fc 100644
--- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
+++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
@@ -55,11 +55,11 @@ class PostprocessingSettings:
 class IPAdapterConditioningInfo:
     cond_image_prompt_embeds: torch.Tensor
     """IP-Adapter image encoder conditioning embeddings.
-    Shape: (batch_size, num_tokens, encoding_dim).
+    Shape: (num_images, num_tokens, encoding_dim).
     """
 
     uncond_image_prompt_embeds: torch.Tensor
     """IP-Adapter image encoding embeddings to use for unconditional generation.
-    Shape: (batch_size, num_tokens, encoding_dim).
+    Shape: (num_images, num_tokens, encoding_dim).
     """
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index c12c86ed92..943fe7b307 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -345,9 +345,12 @@ class InvokeAIDiffuserComponent:
 
         cross_attention_kwargs = None
         if conditioning_data.ip_adapter_conditioning is not None:
+            # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
             cross_attention_kwargs = {
                 "ip_adapter_image_prompt_embeds": [
-                    torch.cat([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
+                    torch.stack(
+                        [ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]
+                    )
                     for ipa_conditioning in conditioning_data.ip_adapter_conditioning
                 ]
             }
@@ -415,9 +418,10 @@ class InvokeAIDiffuserComponent:
         # Run unconditional UNet denoising.
         cross_attention_kwargs = None
         if conditioning_data.ip_adapter_conditioning is not None:
+            # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
            cross_attention_kwargs = {
                 "ip_adapter_image_prompt_embeds": [
-                    ipa_conditioning.uncond_image_prompt_embeds
+                    torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
                     for ipa_conditioning in conditioning_data.ip_adapter_conditioning
                 ]
             }
@@ -444,9 +448,10 @@ class InvokeAIDiffuserComponent:
         # Run conditional UNet denoising.
         cross_attention_kwargs = None
         if conditioning_data.ip_adapter_conditioning is not None:
+            # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
             cross_attention_kwargs = {
                 "ip_adapter_image_prompt_embeds": [
-                    ipa_conditioning.cond_image_prompt_embeds
+                    torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
                     for ipa_conditioning in conditioning_data.ip_adapter_conditioning
                 ]
             }
diff --git a/tests/backend/ip_adapter/test_ip_adapter.py b/tests/backend/ip_adapter/test_ip_adapter.py
index 7f634ee1fe..6712196778 100644
--- a/tests/backend/ip_adapter/test_ip_adapter.py
+++ b/tests/backend/ip_adapter/test_ip_adapter.py
@@ -65,7 +65,10 @@ def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
     ip_adapter.to(torch_device, dtype=torch.float32)
     unet.to(torch_device, dtype=torch.float32)
 
-    cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [torch.randn((1, 4, 768)).to(torch_device)]}
+    # ip_embeds shape: (batch_size, num_ip_images, seq_len, ip_image_embedding_len)
+    ip_embeds = torch.randn((1, 3, 4, 768)).to(torch_device)
+
+    cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [ip_embeds]}
     ip_adapter_unet_patcher = UNetPatcher([ip_adapter])
     with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet):
         output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample
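
As a quick sanity check on the tensor shapes introduced above, the sketch below mirrors the stacking done in InvokeAIDiffuserComponent and the reshaping done in IPAttnProcessor2_0. It is illustrative only: the sizes (num_ip_images=3, ip_seq_len=4, head counts, etc.) and the Linear layers standing in for to_k_ip/to_v_ip are assumptions, not the patched InvokeAI code.

    # Illustrative shape check only -- sizes and the stand-in projections are assumptions.
    import torch
    import torch.nn.functional as F

    num_ip_images, ip_seq_len, embed_dim = 3, 4, 768
    num_heads, head_dim = 8, 40
    query_seq_len = 77

    # Per-IP-Adapter conditioning, as stored in IPAdapterConditioningInfo:
    # shape (num_images, num_tokens, encoding_dim).
    cond_embeds = torch.randn(num_ip_images, ip_seq_len, embed_dim)
    uncond_embeds = torch.randn(num_ip_images, ip_seq_len, embed_dim)

    # torch.stack (as in the batched conditioning path) gives
    # (batch_size=2, num_ip_images, ip_seq_len, embed_dim).
    ip_embeds = torch.stack([uncond_embeds, cond_embeds])
    batch_size = ip_embeds.shape[0]

    # Stand-in key/value projections playing the role of to_k_ip / to_v_ip.
    to_k_ip = torch.nn.Linear(embed_dim, num_heads * head_dim, bias=False)
    to_v_ip = torch.nn.Linear(embed_dim, num_heads * head_dim, bias=False)

    # The .view(batch_size, -1, heads, head_dim) call folds the image dimension into
    # the key/value sequence: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim).
    ip_key = to_k_ip(ip_embeds).view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
    ip_value = to_v_ip(ip_embeds).view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
    assert ip_key.shape == (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)

    # The query length is untouched, so attending over the extra image tokens changes
    # nothing downstream: output is (batch_size, num_heads, query_seq_len, head_dim).
    query = torch.randn(batch_size, num_heads, query_seq_len, head_dim)
    out = F.scaled_dot_product_attention(query, ip_key, ip_value)
    assert out.shape == (batch_size, num_heads, query_seq_len, head_dim)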