diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 771c811eea..ff13658052 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -1,17 +1,11 @@
-from typing import Iterator, List, Optional, Tuple, Union
+from typing import Iterator, List, Optional, Tuple, Union, cast
 
 import torch
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
 
-from invokeai.app.invocations.fields import (
-    FieldDescriptions,
-    Input,
-    InputField,
-    OutputField,
-    UIComponent,
-)
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import generate_ti_list
@@ -25,12 +19,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
 )
 from invokeai.backend.util.devices import torch_dtype
 
-from .baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    invocation,
-    invocation_output,
-)
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from .model import ClipField
 
 # unconditioned: Optional[torch.Tensor]
@@ -149,7 +138,7 @@ class SDXLPromptInvocationBase:
         assert isinstance(tokenizer_model, CLIPTokenizer)
         text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
         text_encoder_model = text_encoder_info.model
-        assert isinstance(text_encoder_model, CLIPTextModel)
+        assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
 
         # return zero on empty
         if prompt == "" and zero_on_empty:
@@ -196,7 +185,8 @@ class SDXLPromptInvocationBase:
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
         ):
-            assert isinstance(text_encoder, CLIPTextModel)
+            assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
+            text_encoder = cast(CLIPTextModel, text_encoder)
             compel = Compel(
                 tokenizer=tokenizer,
                 text_encoder=text_encoder,
diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py
index bee8909c31..473a088308 100644
--- a/invokeai/backend/model_patcher.py
+++ b/invokeai/backend/model_patcher.py
@@ -4,12 +4,12 @@ from __future__ import annotations
 
 import pickle
 from contextlib import contextmanager
-from typing import Any, Dict, Iterator, List, Optional, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 from diffusers import OnnxRuntimeModel, UNet2DConditionModel
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
 
 from invokeai.app.shared.models import FreeUConfig
 from invokeai.backend.model_manager import AnyModel
@@ -168,7 +168,7 @@ class ModelPatcher:
     def apply_ti(
         cls,
         tokenizer: CLIPTokenizer,
-        text_encoder: CLIPTextModel,
+        text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
        ti_list: List[Tuple[str, TextualInversionModelRaw]],
     ) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
         init_tokens_count = None
@@ -265,7 +265,7 @@ class ModelPatcher:
     @contextmanager
     def apply_clip_skip(
         cls,
-        text_encoder: CLIPTextModel,
+        text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
         clip_skip: int,
     ) -> None:
         skipped_layers = []