fix: cleanup across various ip adapter files

This commit is contained in:
blessedcoolant 2024-03-24 02:27:38 +05:30
parent 60bf0caca3
commit 4ed2bf53ca
4 changed files with 40 additions and 39 deletions

View File

@ -91,7 +91,7 @@ class IPAdapterInvocation(BaseInvocation):
image_encoder_model_id = ( image_encoder_model_id = (
ip_adapter_info.image_encoder_model_id ip_adapter_info.image_encoder_model_id
if isinstance(ip_adapter_info, IPAdapterDiffusersConfig) if isinstance(ip_adapter_info, IPAdapterDiffusersConfig)
else "InvokeAI/ip_adapter_sd_image_encoder" else "ip_adapter_sd_image_encoder"
) )
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
image_encoder_model = self._get_image_encoder(context, image_encoder_model_name) image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)

View File

@ -14,10 +14,12 @@ from diffusers import AutoencoderKL, AutoencoderTiny
from diffusers.configuration_utils import ConfigMixin from diffusers.configuration_utils import ConfigMixin
from diffusers.image_processor import VaeImageProcessor from diffusers.image_processor import VaeImageProcessor
from diffusers.models.adapter import T2IAdapter from diffusers.models.adapter import T2IAdapter
from diffusers.models.attention_processor import (AttnProcessor2_0, from diffusers.models.attention_processor import (
LoRAAttnProcessor2_0, AttnProcessor2_0,
LoRAXFormersAttnProcessor, LoRAAttnProcessor2_0,
XFormersAttnProcessor) LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers import DPMSolverSDEScheduler from diffusers.schedulers import DPMSolverSDEScheduler
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
@ -26,17 +28,22 @@ from pydantic import field_validator
from torchvision.transforms.functional import resize as tv_resize from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPVisionModelWithProjection from transformers import CLIPVisionModelWithProjection
from invokeai.app.invocations.constants import (LATENT_SCALE_FACTOR, from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
SCHEDULER_NAME_VALUES) from invokeai.app.invocations.fields import (
from invokeai.app.invocations.fields import (ConditioningField, ConditioningField,
DenoiseMaskField, DenoiseMaskField,
FieldDescriptions, ImageField, FieldDescriptions,
Input, InputField, LatentsField, ImageField,
OutputField, UIType, WithBoard, Input,
WithMetadata) InputField,
LatentsField,
OutputField,
UIType,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.ip_adapter import IPAdapterField from invokeai.app.invocations.ip_adapter import IPAdapterField
from invokeai.app.invocations.primitives import (DenoiseMaskOutput, from invokeai.app.invocations.primitives import DenoiseMaskOutput, ImageOutput, LatentsOutput
ImageOutput, LatentsOutput)
from invokeai.app.invocations.t2i_adapter import T2IAdapterField from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.app.util.controlnet_utils import prepare_control_image
@ -44,19 +51,20 @@ from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType, LoadedModel from invokeai.backend.model_manager import BaseModelType, LoadedModel
from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import (PipelineIntermediateState, from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
set_seamless) from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ConditioningData, IPAdapterConditioningInfo)
from invokeai.backend.util.silence_warnings import SilenceWarnings from invokeai.backend.util.silence_warnings import SilenceWarnings
from ...backend.stable_diffusion.diffusers_pipeline import ( from ...backend.stable_diffusion.diffusers_pipeline import (
ControlNetData, IPAdapterData, StableDiffusionGeneratorPipeline, ControlNetData,
T2IAdapterData, image_resized_to_grid_as_tensor) IPAdapterData,
StableDiffusionGeneratorPipeline,
T2IAdapterData,
image_resized_to_grid_as_tensor,
)
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_precision, choose_torch_device from ...backend.util.devices import choose_precision, choose_torch_device
from .baseinvocation import (BaseInvocation, BaseInvocationOutput, invocation, from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
invocation_output)
from .controlnet_image_processors import ControlField from .controlnet_image_processors import ControlField
from .model import ModelIdentifierField, UNetField, VAEField from .model import ModelIdentifierField, UNetField, VAEField

View File

@ -9,23 +9,16 @@ from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash from invokeai.backend.model_hash.model_hash import (HASHING_ALGORITHMS,
ModelHash)
from invokeai.backend.util.util import SilenceWarnings from invokeai.backend.util.util import SilenceWarnings
from .config import ( from .config import (AnyModelConfig, BaseModelType,
AnyModelConfig, ControlAdapterDefaultSettings,
BaseModelType, InvalidModelConfigException, MainModelDefaultSettings,
ControlAdapterDefaultSettings, ModelConfigFactory, ModelFormat, ModelRepoVariant,
InvalidModelConfigException, ModelSourceType, ModelType, ModelVariantType,
MainModelDefaultSettings, SchedulerPredictionType)
ModelConfigFactory,
ModelFormat,
ModelRepoVariant,
ModelSourceType,
ModelType,
ModelVariantType,
SchedulerPredictionType,
)
from .util.model_util import lora_token_vector_length, read_checkpoint_meta from .util.model_util import lora_token_vector_length, read_checkpoint_meta
CkptType = Dict[str | int, Any] CkptType = Dict[str | int, Any]
@ -536,7 +529,6 @@ class IPAdapterCheckpointProbe(CheckpointProbeBase):
if not key.startswith(("image_proj.", "ip_adapter.")): if not key.startswith(("image_proj.", "ip_adapter.")):
continue continue
cross_attention_dim = checkpoint["ip_adapter.1.to_k_ip.weight"].shape[-1] cross_attention_dim = checkpoint["ip_adapter.1.to_k_ip.weight"].shape[-1]
print(cross_attention_dim)
if cross_attention_dim == 768: if cross_attention_dim == 768:
return BaseModelType.StableDiffusion1 return BaseModelType.StableDiffusion1
elif cross_attention_dim == 1024: elif cross_attention_dim == 1024:

View File

@ -655,6 +655,7 @@
"install": "Install", "install": "Install",
"installAll": "Install All", "installAll": "Install All",
"installRepo": "Install Repo", "installRepo": "Install Repo",
"ipAdapters": "IP Adapters",
"load": "Load", "load": "Load",
"localOnly": "local only", "localOnly": "local only",
"manual": "Manual", "manual": "Manual",