feat: add base model recognition for ip adapter safetensor files

This commit is contained in:
blessedcoolant 2024-03-24 01:58:46 +05:30
parent b013d0e064
commit 60bf0caca3
2 changed files with 34 additions and 40 deletions

View File

@ -14,12 +14,10 @@ from diffusers import AutoencoderKL, AutoencoderTiny
from diffusers.configuration_utils import ConfigMixin from diffusers.configuration_utils import ConfigMixin
from diffusers.image_processor import VaeImageProcessor from diffusers.image_processor import VaeImageProcessor
from diffusers.models.adapter import T2IAdapter from diffusers.models.adapter import T2IAdapter
from diffusers.models.attention_processor import ( from diffusers.models.attention_processor import (AttnProcessor2_0,
AttnProcessor2_0,
LoRAAttnProcessor2_0, LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor, LoRAXFormersAttnProcessor,
XFormersAttnProcessor, XFormersAttnProcessor)
)
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers import DPMSolverSDEScheduler from diffusers.schedulers import DPMSolverSDEScheduler
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
@ -28,26 +26,17 @@ from pydantic import field_validator
from torchvision.transforms.functional import resize as tv_resize from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPVisionModelWithProjection from transformers import CLIPVisionModelWithProjection
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES from invokeai.app.invocations.constants import (LATENT_SCALE_FACTOR,
from invokeai.app.invocations.fields import ( SCHEDULER_NAME_VALUES)
ConditioningField, from invokeai.app.invocations.fields import (ConditioningField,
DenoiseMaskField, DenoiseMaskField,
FieldDescriptions, FieldDescriptions, ImageField,
ImageField, Input, InputField, LatentsField,
Input, OutputField, UIType, WithBoard,
InputField, WithMetadata)
LatentsField,
OutputField,
UIType,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.ip_adapter import IPAdapterField from invokeai.app.invocations.ip_adapter import IPAdapterField
from invokeai.app.invocations.primitives import ( from invokeai.app.invocations.primitives import (DenoiseMaskOutput,
DenoiseMaskOutput, ImageOutput, LatentsOutput)
ImageOutput,
LatentsOutput,
)
from invokeai.app.invocations.t2i_adapter import T2IAdapterField from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.app.util.controlnet_utils import prepare_control_image
@ -55,25 +44,19 @@ from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType, LoadedModel from invokeai.backend.model_manager import BaseModelType, LoadedModel
from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless from invokeai.backend.stable_diffusion import (PipelineIntermediateState,
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo set_seamless)
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ConditioningData, IPAdapterConditioningInfo)
from invokeai.backend.util.silence_warnings import SilenceWarnings from invokeai.backend.util.silence_warnings import SilenceWarnings
from ...backend.stable_diffusion.diffusers_pipeline import ( from ...backend.stable_diffusion.diffusers_pipeline import (
ControlNetData, ControlNetData, IPAdapterData, StableDiffusionGeneratorPipeline,
IPAdapterData, T2IAdapterData, image_resized_to_grid_as_tensor)
StableDiffusionGeneratorPipeline,
T2IAdapterData,
image_resized_to_grid_as_tensor,
)
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_precision, choose_torch_device from ...backend.util.devices import choose_precision, choose_torch_device
from .baseinvocation import ( from .baseinvocation import (BaseInvocation, BaseInvocationOutput, invocation,
BaseInvocation, invocation_output)
BaseInvocationOutput,
invocation,
invocation_output,
)
from .controlnet_image_processors import ControlField from .controlnet_image_processors import ControlField
from .model import ModelIdentifierField, UNetField, VAEField from .model import ModelIdentifierField, UNetField, VAEField

View File

@ -535,7 +535,18 @@ class IPAdapterCheckpointProbe(CheckpointProbeBase):
for key in checkpoint.keys(): for key in checkpoint.keys():
if not key.startswith(("image_proj.", "ip_adapter.")): if not key.startswith(("image_proj.", "ip_adapter.")):
continue continue
cross_attention_dim = checkpoint["ip_adapter.1.to_k_ip.weight"].shape[-1]
print(cross_attention_dim)
if cross_attention_dim == 768:
return BaseModelType.StableDiffusion1
elif cross_attention_dim == 1024:
return BaseModelType.StableDiffusion2
elif cross_attention_dim == 2048:
return BaseModelType.StableDiffusionXL return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(
f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}."
)
raise InvalidModelConfigException(f"{self.model_path}: Unable to determine base type") raise InvalidModelConfigException(f"{self.model_path}: Unable to determine base type")