Extract TI loading logic into a util and prevent it from ever failing a generation

Brandon Rising
2024-02-27 15:20:14 -05:00
committed by Brandon
parent 408a800593
commit e4b8cb1d34
2 changed files with 42 additions and 44 deletions
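Both hunks below replace an inline textual-inversion (TI) loading loop with a single call to the new generate_ti_list() helper in invokeai/app/util/ti_utils.py. As a rough guide, here is a minimal sketch of what that helper plausibly looks like, reconstructed from the removed inline logic: the signature and return type are inferred from the call sites and the old annotation, while the import paths, the module logger, and the trigger-extraction regex are assumptions rather than the commit's exact code.

# Minimal sketch of the extracted helper (assumed to live in invokeai/app/util/ti_utils.py).
# Reconstructed from the inline logic removed in this commit; names, imports, and the
# trigger regex are assumptions, not the exact code added here.
import logging
import re
from typing import List, Tuple

from invokeai.app.services.model_records import UnknownModelException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import ModelType
from invokeai.backend.textual_inversion import TextualInversionModelRaw  # assumed path

logger = logging.getLogger(__name__)  # placeholder; the real code likely uses InvokeAI's logger


def extract_ti_triggers_from_prompt(prompt: str) -> List[str]:
    # Illustrative only: this helper already exists in ti_utils.py (it was the old import).
    return re.findall(r"<[^>]+>", prompt)


def generate_ti_list(
    prompt: str, base, context: InvocationContext
) -> List[Tuple[str, TextualInversionModelRaw]]:
    ti_list: List[Tuple[str, TextualInversionModelRaw]] = []
    for trigger in extract_ti_triggers_from_prompt(prompt):
        name_or_key = trigger[1:-1]
        try:
            # First treat the trigger as a model key...
            loaded_model = context.models.load(key=name_or_key)
            model = loaded_model.model
            assert isinstance(model, TextualInversionModelRaw)
            ti_list.append((name_or_key, model))
        except UnknownModelException:
            try:
                # ...then fall back to looking it up by name for the current base model.
                loaded_model = context.models.load_by_attrs(
                    model_name=name_or_key, base_model=base, model_type=ModelType.TextualInversion
                )
                model = loaded_model.model
                assert isinstance(model, TextualInversionModelRaw)
                ti_list.append((name_or_key, model))
            except UnknownModelException:
                logger.warning(f'trigger: "{trigger}" not found')
            except ValueError:
                logger.warning(f'trigger: "{trigger}" has more than one similarly-named textual inversion model')
    # Every lookup failure is downgraded to a warning, so a bad trigger never fails the generation.
    return ti_list

The key behavioural point, per the commit title, is that unknown or ambiguous triggers are logged and skipped rather than raising, so conditioning can no longer fail because of a bad TI trigger in the prompt.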


@@ -16,7 +16,7 @@ from invokeai.app.invocations.fields import (
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.model_records import UnknownModelException
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt
+from invokeai.app.util.ti_utils import generate_ti_list
 from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager.config import ModelType
 from invokeai.backend.model_patcher import ModelPatcher
@@ -86,26 +86,7 @@ class CompelInvocation(BaseInvocation):
         # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
-        ti_list: List[Tuple[str, TextualInversionModelRaw]] = []
-        for trigger in extract_ti_triggers_from_prompt(self.prompt):
-            name_or_key = trigger[1:-1]
-            try:
-                loaded_model = context.models.load(key=name_or_key)
-                model = loaded_model.model
-                assert isinstance(model, TextualInversionModelRaw)
-                ti_list.append((name_or_key, model))
-            except UnknownModelException:
-                try:
-                    loaded_model = context.models.load_by_attrs(
-                        model_name=name_or_key, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
-                    )
-                    model = loaded_model.model
-                    assert isinstance(model, TextualInversionModelRaw)
-                    ti_list.append((name_or_key, model))
-                except UnknownModelException:
-                    logger.warning(f'trigger: "{trigger}" not found')
-                except ValueError:
-                    logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
+        ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)

         with (
             ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
@@ -206,26 +187,7 @@ class SDXLPromptInvocationBase:
         # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
-        ti_list: List[Tuple[str, TextualInversionModelRaw]] = []
-        for trigger in extract_ti_triggers_from_prompt(prompt):
-            name_or_key = trigger[1:-1]
-            try:
-                loaded_model = context.models.load(key=name_or_key)
-                model = loaded_model.model
-                assert isinstance(model, TextualInversionModelRaw)
-                ti_list.append((name_or_key, model))
-            except UnknownModelException:
-                try:
-                    loaded_model = context.models.load_by_attrs(
-                        model_name=name_or_key, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
-                    )
-                    model = loaded_model.model
-                    assert isinstance(model, TextualInversionModelRaw)
-                    ti_list.append((name_or_key, model))
-                except UnknownModelException:
-                    logger.warning(f'trigger: "{trigger}" not found')
-                except ValueError:
-                    logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
+        ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)

         with (
             ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (