diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 59897f5333..44cfd67d0c 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -50,6 +50,7 @@ from invokeai.app.invocations.primitives import DenoiseMaskOutput, ImageOutput, from invokeai.app.invocations.t2i_adapter import T2IAdapterField from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.controlnet_utils import prepare_control_image +from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.model_manager import BaseModelType, LoadedModel from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType @@ -674,22 +675,13 @@ class DenoiseLatentsInvocation(BaseInvocation): def prep_ip_adapter_image_prompts( self, context: InvocationContext, - ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]], + ip_adapters: List[IPAdapterField], ) -> List[Tuple[torch.Tensor, torch.Tensor]]: """Run the IPAdapter CLIPVisionModel, returning image prompt embeddings.""" - # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here. - if ip_adapter is None: - return [] - - if not isinstance(ip_adapter, list): - ip_adapter = [ip_adapter] - - if len(ip_adapter) == 0: - return [] - image_prompts = [] - for single_ip_adapter in ip_adapter: + for single_ip_adapter in ip_adapters: with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model: + assert isinstance(ip_adapter_model, IPAdapter) image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model) # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here. single_ipa_image_fields = single_ip_adapter.image @@ -710,36 +702,23 @@ class DenoiseLatentsInvocation(BaseInvocation): def prep_ip_adapter_data( self, context: InvocationContext, - ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]], + ip_adapters: List[IPAdapterField], + image_prompts: List[Tuple[torch.Tensor, torch.Tensor]], exit_stack: ExitStack, latent_height: int, latent_width: int, dtype: torch.dtype, - image_prompts: List[Tuple[torch.Tensor, torch.Tensor]], - ) -> Optional[list[IPAdapterData]]: - """If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings - to the `conditioning_data` (in-place). - """ - if ip_adapter is None: - return None - - # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here. - if not isinstance(ip_adapter, list): - ip_adapter = [ip_adapter] - - if len(ip_adapter) == 0: - return None - + ) -> Optional[List[IPAdapterData]]: + """If IP-Adapter is enabled, then this function loads the requisite models and adds the image prompt conditioning data.""" ip_adapter_data_list = [] - assert len(ip_adapter) == len(image_prompts) - for single_ip_adapter in ip_adapter: + assert len(ip_adapters) == len(image_prompts) + for single_ip_adapter in ip_adapters: ip_adapter_model = exit_stack.enter_context(context.models.load(single_ip_adapter.ip_adapter_model)) image_prompt_embeds, uncond_image_prompt_embeds = image_prompts.pop(0) - mask = single_ip_adapter.mask - if mask is not None: - mask = context.tensors.load(mask.tensor_name) + mask_field = single_ip_adapter.mask + mask = context.tensors.load(mask_field.tensor_name) if mask_field is not None else None mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype) ip_adapter_data_list.append( @@ -754,7 +733,7 @@ class DenoiseLatentsInvocation(BaseInvocation): ) ) - return ip_adapter_data_list + return ip_adapter_data_list if len(ip_adapter_data_list) > 0 else None def run_t2i_adapters( self, @@ -932,7 +911,15 @@ class DenoiseLatentsInvocation(BaseInvocation): do_classifier_free_guidance=True, ) - image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapter=self.ip_adapter) + ip_adapters: List[IPAdapterField] = [] + if self.ip_adapter is not None: + # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here. + if isinstance(self.ip_adapter, list): + ip_adapters = self.ip_adapter + else: + ip_adapters = [self.ip_adapter] + + image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters) # get the unet's config so that we can pass the base to dispatch_progress() unet_config = context.models.get_config(self.unet.unet.key) @@ -992,12 +979,12 @@ class DenoiseLatentsInvocation(BaseInvocation): ip_adapter_data = self.prep_ip_adapter_data( context=context, - ip_adapter=self.ip_adapter, + ip_adapters=ip_adapters, + image_prompts=image_prompts, exit_stack=exit_stack, latent_height=latent_height, latent_width=latent_width, dtype=unet.dtype, - image_prompts=image_prompts, ) num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(