fix: Assertion issue with SDXL Compel

Author: blessedcoolant
Date: 2024-02-29 11:32:28 +05:30
Committed by: psychedelicious
Parent: 01898d766f
Commit: ae34bcfbc0
2 changed files with 11 additions and 21 deletions


@@ -1,17 +1,11 @@
-from typing import Iterator, List, Optional, Tuple, Union
+from typing import Iterator, List, Optional, Tuple, Union, cast
 
 import torch
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
 
-from invokeai.app.invocations.fields import (
-    FieldDescriptions,
-    Input,
-    InputField,
-    OutputField,
-    UIComponent,
-)
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import generate_ti_list
@@ -25,12 +19,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
 )
 from invokeai.backend.util.devices import torch_dtype
 
-from .baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    invocation,
-    invocation_output,
-)
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from .model import ClipField
 
 # unconditioned: Optional[torch.Tensor]
@@ -149,7 +138,7 @@ class SDXLPromptInvocationBase:
         assert isinstance(tokenizer_model, CLIPTokenizer)
         text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
         text_encoder_model = text_encoder_info.model
-        assert isinstance(text_encoder_model, CLIPTextModel)
+        assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
 
         # return zero on empty
         if prompt == "" and zero_on_empty:
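
Note on the hunk above (a minimal, hypothetical sketch, not code from this commit): SDXL uses two CLIP text encoders, and the second one is loaded as a transformers CLIPTextModelWithProjection rather than a plain CLIPTextModel, so an assert pinned to CLIPTextModel alone raises AssertionError on the SDXL path. Widening the isinstance check to accept either class resolves that. The helper name below is made up for illustration.

from transformers import CLIPTextModel, CLIPTextModelWithProjection


def check_loaded_text_encoder(text_encoder_model: object) -> None:
    # Hypothetical helper mirroring the widened assertion above.
    # The old check, `assert isinstance(text_encoder_model, CLIPTextModel)`,
    # fails when the loaded SDXL encoder is a CLIPTextModelWithProjection;
    # accepting either class keeps both SD and SDXL encoders valid.
    assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
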
@@ -196,7 +185,8 @@ class SDXLPromptInvocationBase:
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
         ):
-            assert isinstance(text_encoder, CLIPTextModel)
+            assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
+            text_encoder = cast(CLIPTextModel, text_encoder)
             compel = Compel(
                 tokenizer=tokenizer,
                 text_encoder=text_encoder,
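
What the added cast does (an illustrative sketch, not part of the repository, under the assumption that Compel's type hints expect a CLIPTextModel): typing.cast performs no runtime conversion and only narrows the static type, so the very same encoder module is passed through to Compel while the type checker stops flagging the union type.

from typing import Union, cast

from transformers import CLIPTextModel, CLIPTextModelWithProjection


def narrow_for_compel(
    text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
) -> CLIPTextModel:
    # Hypothetical helper: cast() is a no-op at runtime, so the same encoder
    # instance is returned; only the static type seen by the checker changes.
    return cast(CLIPTextModel, text_encoder)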