mirror of https://github.com/invoke-ai/InvokeAI
fix: Assertion issue with SDXL Compel
parent 01898d766f
commit ae34bcfbc0
invokeai/app/invocations/compel.py

@@ -1,17 +1,11 @@
-from typing import Iterator, List, Optional, Tuple, Union
+from typing import Iterator, List, Optional, Tuple, Union, cast
 
 import torch
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
 
-from invokeai.app.invocations.fields import (
-    FieldDescriptions,
-    Input,
-    InputField,
-    OutputField,
-    UIComponent,
-)
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import generate_ti_list
@@ -25,12 +19,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
 )
 from invokeai.backend.util.devices import torch_dtype
 
-from .baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    invocation,
-    invocation_output,
-)
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from .model import ClipField
 
 # unconditioned: Optional[torch.Tensor]
@@ -149,7 +138,7 @@ class SDXLPromptInvocationBase:
         assert isinstance(tokenizer_model, CLIPTokenizer)
         text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
         text_encoder_model = text_encoder_info.model
-        assert isinstance(text_encoder_model, CLIPTextModel)
+        assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
 
         # return zero on empty
         if prompt == "" and zero_on_empty:
@@ -196,7 +185,8 @@ class SDXLPromptInvocationBase:
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
         ):
-            assert isinstance(text_encoder, CLIPTextModel)
+            assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
+            text_encoder = cast(CLIPTextModel, text_encoder)
             compel = Compel(
                 tokenizer=tokenizer,
                 text_encoder=text_encoder,
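Context for the widened assertions above: SDXL uses two text encoders, and the second one loads as a CLIPTextModelWithProjection rather than a plain CLIPTextModel, so the old single-class isinstance assertion failed whenever an SDXL CLIP field reached this code. The added typing.cast is a runtime no-op that only narrows the static type for the downstream Compel call. A minimal sketch of the pattern, using an arbitrary toy config rather than real model weights:

    from typing import Union, cast

    from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection

    # Toy config; real encoders would be loaded from checkpoint weights.
    config = CLIPTextConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=2, num_attention_heads=2)

    # Stand-in for SDXL's second text encoder.
    text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection] = CLIPTextModelWithProjection(config)

    # Widened check: either encoder class passes, where the old
    # CLIPTextModel-only assert raised AssertionError for SDXL.
    assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))

    # Does nothing at runtime; only narrows the static type so a callee
    # annotated with CLIPTextModel still type-checks.
    text_encoder = cast(CLIPTextModel, text_encoder)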
invokeai/backend/model_patcher.py

@@ -4,12 +4,12 @@ from __future__ import annotations
 
 import pickle
 from contextlib import contextmanager
-from typing import Any, Dict, Iterator, List, Optional, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 from diffusers import OnnxRuntimeModel, UNet2DConditionModel
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
 
 from invokeai.app.shared.models import FreeUConfig
 from invokeai.backend.model_manager import AnyModel
@@ -168,7 +168,7 @@ class ModelPatcher:
     def apply_ti(
         cls,
         tokenizer: CLIPTokenizer,
-        text_encoder: CLIPTextModel,
+        text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
         ti_list: List[Tuple[str, TextualInversionModelRaw]],
     ) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
         init_tokens_count = None
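The same Union widening is applied to ModelPatcher.apply_ti, which patches the encoder's token-embedding table and so relies only on API that both classes share. A small sketch under that assumption (token_embedding_dim is a hypothetical helper for illustration, not InvokeAI code):

    from typing import Union

    from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection

    def token_embedding_dim(text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection]) -> int:
        # get_input_embeddings() returns the token-embedding table on both
        # classes; this is the surface textual-inversion patching works on.
        return text_encoder.get_input_embeddings().embedding_dim

    config = CLIPTextConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=2, num_attention_heads=2)
    print(token_embedding_dim(CLIPTextModel(config)))                # 64
    print(token_embedding_dim(CLIPTextModelWithProjection(config)))  # 64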
@@ -265,7 +265,7 @@ class ModelPatcher:
     @contextmanager
     def apply_clip_skip(
         cls,
-        text_encoder: CLIPTextModel,
+        text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
         clip_skip: int,
     ) -> None:
         skipped_layers = []
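apply_clip_skip gets the same signature change; it manipulates text_model.encoder.layers, which both encoder classes expose, so the Union type is sound there too. A hedged sketch of the general clip-skip contextmanager pattern (clip_skip below is illustrative, not InvokeAI's exact implementation):

    from contextlib import contextmanager
    from typing import Iterator, Union

    from transformers import CLIPTextModel, CLIPTextModelWithProjection

    @contextmanager
    def clip_skip(
        text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
        n_skipped: int,
    ) -> Iterator[None]:
        # Temporarily drop the last n transformer blocks, restore them on exit.
        layers = text_encoder.text_model.encoder.layers
        skipped = []
        try:
            for _ in range(n_skipped):
                skipped.append(layers[-1])
                del layers[-1]
            yield
        finally:
            while skipped:
                layers.append(skipped.pop())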