fix: Assertion issue with SDXL Compel

blessedcoolant 2024-02-29 11:32:28 +05:30 committed by psychedelicious
parent 01898d766f
commit ae34bcfbc0
2 changed files with 11 additions and 21 deletions

File 1 of 2

@@ -1,17 +1,11 @@
-from typing import Iterator, List, Optional, Tuple, Union
+from typing import Iterator, List, Optional, Tuple, Union, cast
 import torch
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
-from invokeai.app.invocations.fields import (
-    FieldDescriptions,
-    Input,
-    InputField,
-    OutputField,
-    UIComponent,
-)
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import generate_ti_list
@@ -25,12 +19,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
 )
 from invokeai.backend.util.devices import torch_dtype
-from .baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    invocation,
-    invocation_output,
-)
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from .model import ClipField
 # unconditioned: Optional[torch.Tensor]
@@ -149,7 +138,7 @@ class SDXLPromptInvocationBase:
         assert isinstance(tokenizer_model, CLIPTokenizer)
         text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
         text_encoder_model = text_encoder_info.model
-        assert isinstance(text_encoder_model, CLIPTextModel)
+        assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
         # return zero on empty
         if prompt == "" and zero_on_empty:
@@ -196,7 +185,8 @@ class SDXLPromptInvocationBase:
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
         ):
-            assert isinstance(text_encoder, CLIPTextModel)
+            assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
+            text_encoder = cast(CLIPTextModel, text_encoder)
             compel = Compel(
                 tokenizer=tokenizer,
                 text_encoder=text_encoder,
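
The hunks above are the heart of the fix: SDXL loads two CLIP text encoders, and the second one is a CLIPTextModelWithProjection rather than a plain CLIPTextModel, so the old single-class isinstance assertions rejected it. The diff widens the assertions to accept either class and adds a cast to CLIPTextModel before the encoder is handed to Compel, presumably to satisfy Compel's type hints. Below is a minimal sketch of that pattern; the helper name is hypothetical and not part of the commit, and the cast only affects static type checking, not runtime behaviour.

from typing import Union, cast

from transformers import CLIPTextModel, CLIPTextModelWithProjection


def narrow_text_encoder(
    text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
) -> CLIPTextModel:
    # Accept either CLIP text encoder variant (SDXL's second encoder is a
    # CLIPTextModelWithProjection), mirroring the widened assert in the diff above.
    assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
    # The cast keeps a CLIPTextModel-typed consumer (such as Compel) happy for the
    # type checker and does nothing at runtime.
    return cast(CLIPTextModel, text_encoder)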

File 2 of 2

@@ -4,12 +4,12 @@ from __future__ import annotations
 import pickle
 from contextlib import contextmanager
-from typing import Any, Dict, Iterator, List, Optional, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 import numpy as np
 import torch
 from diffusers import OnnxRuntimeModel, UNet2DConditionModel
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
 from invokeai.app.shared.models import FreeUConfig
 from invokeai.backend.model_manager import AnyModel
@@ -168,7 +168,7 @@ class ModelPatcher:
     def apply_ti(
         cls,
         tokenizer: CLIPTokenizer,
-        text_encoder: CLIPTextModel,
+        text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
         ti_list: List[Tuple[str, TextualInversionModelRaw]],
     ) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
         init_tokens_count = None
@@ -265,7 +265,7 @@ class ModelPatcher:
     @contextmanager
     def apply_clip_skip(
         cls,
-        text_encoder: CLIPTextModel,
+        text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
         clip_skip: int,
     ) -> None:
         skipped_layers = []
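
The second file keeps the ModelPatcher side consistent: apply_ti and apply_clip_skip now annotate text_encoder as Union[CLIPTextModel, CLIPTextModelWithProjection], so either encoder variant can be patched. For illustration, here is a hedged, self-contained sketch of a clip-skip context manager with the same widened signature; it follows the pop-and-restore-last-layers approach suggested by the skipped_layers list above, but it is a sketch, not ModelPatcher's actual implementation.

from contextlib import contextmanager
from typing import Iterator, Union

from transformers import CLIPTextModel, CLIPTextModelWithProjection


@contextmanager
def clip_skip_sketch(
    text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
    clip_skip: int,
) -> Iterator[None]:
    # Both encoder classes expose text_model.encoder.layers (an nn.ModuleList),
    # which is why a plain Union annotation is sufficient here.
    skipped_layers = []
    try:
        # Temporarily remove the last `clip_skip` transformer layers.
        for _ in range(clip_skip):
            skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1))
        yield
    finally:
        # Re-append the removed layers so the original order is restored.
        while skipped_layers:
            text_encoder.text_model.encoder.layers.append(skipped_layers.pop())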