diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index a88eff0fcb..8dd90a18e6 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -50,7 +50,6 @@ from invokeai.app.invocations.primitives import DenoiseMaskOutput, ImageOutput,
 from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
-from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
 from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager import BaseModelType, LoadedModel
 from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
@@ -672,6 +671,39 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         return controlnet_data
 
+    def prep_ip_adapter_image_prompts(
+        self,
+        context: InvocationContext,
+        ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
+    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
+        """Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
+        # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
+        if not isinstance(ip_adapter, list):
+            ip_adapter = [ip_adapter]
+
+        if len(ip_adapter) == 0:
+            return []
+
+        image_prompts = []
+        for single_ip_adapter in ip_adapter:
+            with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model:
+                image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
+                # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
+                single_ipa_image_fields = single_ip_adapter.image
+                if not isinstance(single_ipa_image_fields, list):
+                    single_ipa_image_fields = [single_ipa_image_fields]
+
+                single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]
+                with image_encoder_model_info as image_encoder_model:
+                    assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
+                    # Get image embeddings from CLIP and ImageProjModel.
+                    image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
+                        single_ipa_images, image_encoder_model
+                    )
+                    image_prompts.append((image_prompt_embeds, uncond_image_prompt_embeds))
+
+        return image_prompts
+
     def prep_ip_adapter_data(
         self,
         context: InvocationContext,
@@ -680,6 +712,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         latent_height: int,
         latent_width: int,
         dtype: torch.dtype,
+        image_prompts: List[Tuple[torch.Tensor, torch.Tensor]],
     ) -> Optional[list[IPAdapterData]]:
         """If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt
         embeddings to the `conditioning_data` (in-place).
@@ -696,26 +729,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         ip_adapter_data_list = []
         for single_ip_adapter in ip_adapter:
-            ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
-                context.models.load(single_ip_adapter.ip_adapter_model)
-            )
+            ip_adapter_model = exit_stack.enter_context(context.models.load(single_ip_adapter.ip_adapter_model))
 
-            image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
-            # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
-            single_ipa_image_fields = single_ip_adapter.image
-            if not isinstance(single_ipa_image_fields, list):
-                single_ipa_image_fields = [single_ipa_image_fields]
-
-            single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]
-
-            # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
-            # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
-            with image_encoder_model_info as image_encoder_model:
-                assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
-                # Get image embeddings from CLIP and ImageProjModel.
-                image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
-                    single_ipa_images, image_encoder_model
-                )
+            image_prompt_embeds, uncond_image_prompt_embeds = image_prompts.pop(0)
 
             mask = single_ip_adapter.mask
             if mask is not None:
@@ -912,6 +928,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 do_classifier_free_guidance=True,
             )
 
+            image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapter=self.ip_adapter)
+
             # get the unet's config so that we can pass the base to dispatch_progress()
             unet_config = context.models.get_config(self.unet.unet.key)
 
@@ -975,6 +993,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
                     latent_height=latent_height,
                     latent_width=latent_width,
                     dtype=unet.dtype,
+                    image_prompts=image_prompts,
                 )
 
                 num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
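
Reviewer note: the sketch below is a minimal, self-contained illustration of the load-ordering pattern this diff adopts. It uses placeholder names (load_model, encode_image_prompts, denoise) rather than InvokeAI APIs, and is only meant to mirror how prep_ip_adapter_image_prompts() now runs and releases the CLIP Vision encoder before the heavier models are brought into memory, after which prep_ip_adapter_data() consumes the precomputed (cond, uncond) pairs in order via image_prompts.pop(0).

# Illustrative sketch of the load-ordering pattern (placeholder names, not InvokeAI APIs).
from contextlib import contextmanager
from typing import Iterator, List, Tuple


@contextmanager
def load_model(name: str) -> Iterator[str]:
    # Stand-in for a model loader: the "model" is resident only inside the with-block.
    print(f"load {name}")
    try:
        yield name
    finally:
        print(f"release {name}")


def encode_image_prompts(adapters: List[str]) -> List[Tuple[str, str]]:
    # Analogue of prep_ip_adapter_image_prompts(): run the vision encoder per adapter
    # while no other large models are resident, collecting (cond, uncond) pairs.
    prompts: List[Tuple[str, str]] = []
    for adapter in adapters:
        with load_model("clip_vision_encoder") as encoder:
            prompts.append((f"{adapter}/{encoder}/cond", f"{adapter}/{encoder}/uncond"))
    return prompts


def denoise(adapters: List[str], image_prompts: List[Tuple[str, str]]) -> None:
    # Analogue of prep_ip_adapter_data() plus denoising: consume the precomputed
    # embeddings in order (pop(0) per adapter) with only the heavy model loaded.
    with load_model("unet"):
        for adapter in adapters:
            cond, uncond = image_prompts.pop(0)
            print(f"{adapter} conditioned with {cond} and {uncond}")


if __name__ == "__main__":
    adapters = ["ip_adapter_0", "ip_adapter_1"]
    prompts = encode_image_prompts(adapters)  # vision encoder loaded and released first
    denoise(adapters, prompts)                # heavier model loaded only afterwards

Consuming the list with pop(0) keeps the embeddings aligned with the per-adapter iteration order, which is why the new image_prompts keyword argument can be threaded through prep_ip_adapter_data() without any other changes to that loop.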