diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index bc79efdeba..8ad1684bcb 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -14,12 +14,10 @@ from diffusers import AutoencoderKL, AutoencoderTiny from diffusers.configuration_utils import ConfigMixin from diffusers.image_processor import VaeImageProcessor from diffusers.models.adapter import T2IAdapter -from diffusers.models.attention_processor import ( - AttnProcessor2_0, - LoRAAttnProcessor2_0, - LoRAXFormersAttnProcessor, - XFormersAttnProcessor, -) +from diffusers.models.attention_processor import (AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor) from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.schedulers import DPMSolverSDEScheduler from diffusers.schedulers import SchedulerMixin as Scheduler @@ -28,26 +26,17 @@ from pydantic import field_validator from torchvision.transforms.functional import resize as tv_resize from transformers import CLIPVisionModelWithProjection -from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES -from invokeai.app.invocations.fields import ( - ConditioningField, - DenoiseMaskField, - FieldDescriptions, - ImageField, - Input, - InputField, - LatentsField, - OutputField, - UIType, - WithBoard, - WithMetadata, -) +from invokeai.app.invocations.constants import (LATENT_SCALE_FACTOR, + SCHEDULER_NAME_VALUES) +from invokeai.app.invocations.fields import (ConditioningField, + DenoiseMaskField, + FieldDescriptions, ImageField, + Input, InputField, LatentsField, + OutputField, UIType, WithBoard, + WithMetadata) from invokeai.app.invocations.ip_adapter import IPAdapterField -from invokeai.app.invocations.primitives import ( - DenoiseMaskOutput, - ImageOutput, - LatentsOutput, -) +from invokeai.app.invocations.primitives import (DenoiseMaskOutput, + ImageOutput, LatentsOutput) from invokeai.app.invocations.t2i_adapter import T2IAdapterField from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.controlnet_utils import prepare_control_image @@ -55,25 +44,19 @@ from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.model_manager import BaseModelType, LoadedModel from invokeai.backend.model_patcher import ModelPatcher -from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless -from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo +from invokeai.backend.stable_diffusion import (PipelineIntermediateState, + set_seamless) +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( + ConditioningData, IPAdapterConditioningInfo) from invokeai.backend.util.silence_warnings import SilenceWarnings from ...backend.stable_diffusion.diffusers_pipeline import ( - ControlNetData, - IPAdapterData, - StableDiffusionGeneratorPipeline, - T2IAdapterData, - image_resized_to_grid_as_tensor, -) + ControlNetData, IPAdapterData, StableDiffusionGeneratorPipeline, + T2IAdapterData, image_resized_to_grid_as_tensor) from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.util.devices import choose_precision, choose_torch_device -from .baseinvocation import ( - BaseInvocation, - BaseInvocationOutput, - invocation, - invocation_output, -) +from .baseinvocation import (BaseInvocation, BaseInvocationOutput, invocation, + invocation_output) from .controlnet_image_processors import ControlField from .model import ModelIdentifierField, UNetField, VAEField diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index ed73fc56c6..bd47cc1a48 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -535,7 +535,18 @@ class IPAdapterCheckpointProbe(CheckpointProbeBase): for key in checkpoint.keys(): if not key.startswith(("image_proj.", "ip_adapter.")): continue - return BaseModelType.StableDiffusionXL + cross_attention_dim = checkpoint["ip_adapter.1.to_k_ip.weight"].shape[-1] + print(cross_attention_dim) + if cross_attention_dim == 768: + return BaseModelType.StableDiffusion1 + elif cross_attention_dim == 1024: + return BaseModelType.StableDiffusion2 + elif cross_attention_dim == 2048: + return BaseModelType.StableDiffusionXL + else: + raise InvalidModelConfigException( + f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}." + ) raise InvalidModelConfigException(f"{self.model_path}: Unable to determine base type")