From 5922cee541cd4235ebbe782012aa4c11966f4009 Mon Sep 17 00:00:00 2001
From: Brandon Rising
Date: Tue, 27 Feb 2024 12:26:51 -0500
Subject: [PATCH] Allow TIs to be either a key or a name in the prompt during
 our transition to using keys

---
 invokeai/app/invocations/compel.py | 72 +++++++++++++++++++-----------
 1 file changed, 47 insertions(+), 25 deletions(-)

diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 50f5322513..ce9b1948eb 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -3,7 +3,7 @@ from typing import Iterator, List, Optional, Tuple, Union
 import torch
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
-from transformers import CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTokenizer
 
 import invokeai.backend.util.logging as logger
 from invokeai.app.invocations.fields import (
@@ -18,7 +18,7 @@ from invokeai.app.services.model_records import UnknownModelException
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt
 from invokeai.backend.lora import LoRAModelRaw
-from invokeai.backend.model_manager import ModelType
+from invokeai.backend.model_manager.config import ModelType
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
@@ -70,7 +70,11 @@ class CompelInvocation(BaseInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
         tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
+        tokenizer_model = tokenizer_info.model
+        assert isinstance(tokenizer_model, CLIPTokenizer)
         text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
+        text_encoder_model = text_encoder_info.model
+        assert isinstance(text_encoder_model, CLIPTextModel)
 
         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in self.clip.loras:
@@ -82,21 +86,29 @@ class CompelInvocation(BaseInvocation):
 
         # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
 
-        ti_list = []
+        ti_list: List[Tuple[str, TextualInversionModelRaw]] = []
         for trigger in extract_ti_triggers_from_prompt(self.prompt):
-            name = trigger[1:-1]
+            name_or_key = trigger[1:-1]
             try:
-                loaded_model = context.models.load(key=name).model
-                assert isinstance(loaded_model, TextualInversionModelRaw)
-                ti_list.append((name, loaded_model))
+                loaded_model = context.models.load(key=name_or_key)
+                model = loaded_model.model
+                assert isinstance(model, TextualInversionModelRaw)
+                ti_list.append((name_or_key, model))
             except UnknownModelException:
-                # print(e)
-                # import traceback
-                # print(traceback.format_exc())
-                print(f'Warn: trigger: "{trigger}" not found')
+                try:
+                    loaded_model = context.models.load_by_attrs(
+                        model_name=name_or_key, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
+                    )
+                    model = loaded_model.model
+                    assert isinstance(model, TextualInversionModelRaw)
+                    ti_list.append((name_or_key, model))
+                except UnknownModelException:
+                    logger.warning(f'trigger: "{trigger}" not found')
+            except ValueError:
+                logger.warning(f'trigger: "{trigger}" matched more than one similarly-named textual inversion model')
 
         with (
-            ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
+            ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
                 tokenizer,
                 ti_manager,
             ),
@@ -106,6 +118,7 @@ class CompelInvocation(BaseInvocation):
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder_info.model, self.clip.skipped_layers),
         ):
+            assert isinstance(text_encoder, CLIPTextModel)
             compel = Compel(
                 tokenizer=tokenizer,
                 text_encoder=text_encoder,
@@ -155,7 +168,11 @@ class SDXLPromptInvocationBase:
         zero_on_empty: bool,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
         tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
+        tokenizer_model = tokenizer_info.model
+        assert isinstance(tokenizer_model, CLIPTokenizer)
         text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
+        text_encoder_model = text_encoder_info.model
+        assert isinstance(text_encoder_model, CLIPTextModel)
 
         # return zero on empty
         if prompt == "" and zero_on_empty:
@@ -189,25 +206,29 @@ class SDXLPromptInvocationBase:
 
         # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
 
-        ti_list = []
+        ti_list: List[Tuple[str, TextualInversionModelRaw]] = []
         for trigger in extract_ti_triggers_from_prompt(prompt):
-            name = trigger[1:-1]
+            name_or_key = trigger[1:-1]
             try:
-                ti_model = context.models.load_by_attrs(
-                    model_name=name, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
-                ).model
-                assert isinstance(ti_model, TextualInversionModelRaw)
-                ti_list.append((name, ti_model))
+                loaded_model = context.models.load(key=name_or_key)
+                model = loaded_model.model
+                assert isinstance(model, TextualInversionModelRaw)
+                ti_list.append((name_or_key, model))
             except UnknownModelException:
-                # print(e)
-                # import traceback
-                # print(traceback.format_exc())
-                logger.warning(f'trigger: "{trigger}" not found')
+                try:
+                    loaded_model = context.models.load_by_attrs(
+                        model_name=name_or_key, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
+                    )
+                    model = loaded_model.model
+                    assert isinstance(model, TextualInversionModelRaw)
+                    ti_list.append((name_or_key, model))
+                except UnknownModelException:
+                    logger.warning(f'trigger: "{trigger}" not found')
             except ValueError:
                 logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
 
         with (
-            ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
+            ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
                 tokenizer,
                 ti_manager,
             ),
@@ -215,8 +236,9 @@ class SDXLPromptInvocationBase:
             # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
             ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
-            ModelPatcher.apply_clip_skip(text_encoder_info.model, clip_field.skipped_layers),
+            ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
         ):
+            assert isinstance(text_encoder, CLIPTextModel)
             compel = Compel(
                 tokenizer=tokenizer,
                 text_encoder=text_encoder,
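
The hunks above turn TI trigger resolution into a two-step lookup: the token between "<" and ">" is first treated as a model key, and only if that raises UnknownModelException is it retried as a model name via load_by_attrs(). The sketch below illustrates that fallback in isolation; Registry, its methods, and the sample entries are hypothetical stand-ins for InvokeAI's model manager, not its real API.

    # Minimal, self-contained sketch of the key-then-name fallback.
    # Registry, UnknownModelException, and the sample data are made up
    # for illustration; only the lookup order mirrors the patch.
    from typing import Optional


    class UnknownModelException(Exception):
        pass


    class Registry:
        def __init__(self) -> None:
            # One embedding registered under both its key and its name (hypothetical values).
            self._by_key = {"26d0f1bd": "easynegative.safetensors"}
            self._by_name = {"EasyNegative": "easynegative.safetensors"}

        def load(self, key: str) -> str:
            if key not in self._by_key:
                raise UnknownModelException(key)
            return self._by_key[key]

        def load_by_attrs(self, model_name: str) -> str:
            if model_name not in self._by_name:
                raise UnknownModelException(model_name)
            return self._by_name[model_name]


    def resolve_trigger(registry: Registry, trigger: str) -> Optional[str]:
        name_or_key = trigger[1:-1]  # strip the surrounding "<" and ">"
        try:
            return registry.load(key=name_or_key)  # new behavior: treat the token as a key
        except UnknownModelException:
            try:
                return registry.load_by_attrs(model_name=name_or_key)  # legacy: treat it as a name
            except UnknownModelException:
                return None  # the invocation logs a warning and skips the trigger


    registry = Registry()
    assert resolve_trigger(registry, "<26d0f1bd>") == "easynegative.safetensors"
    assert resolve_trigger(registry, "<EasyNegative>") == "easynegative.safetensors"
    assert resolve_trigger(registry, "<missing>") is None

Catching UnknownModelException on the key path, rather than checking the token's format, keeps existing prompts with name-based triggers working unchanged while the frontend migrates to key-based triggers.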