mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Extract TI loading logic into util, disallow it from ever failing a generation
This commit is contained in:
parent
9a1e55a305
commit
2a6722bb6c
@ -16,7 +16,7 @@ from invokeai.app.invocations.fields import (
|
|||||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||||
from invokeai.app.services.model_records import UnknownModelException
|
from invokeai.app.services.model_records import UnknownModelException
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt
|
from invokeai.app.util.ti_utils import generate_ti_list
|
||||||
from invokeai.backend.lora import LoRAModelRaw
|
from invokeai.backend.lora import LoRAModelRaw
|
||||||
from invokeai.backend.model_manager.config import ModelType
|
from invokeai.backend.model_manager.config import ModelType
|
||||||
from invokeai.backend.model_patcher import ModelPatcher
|
from invokeai.backend.model_patcher import ModelPatcher
|
||||||
@ -86,26 +86,7 @@ class CompelInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||||
|
|
||||||
ti_list: List[Tuple[str, TextualInversionModelRaw]] = []
|
ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
|
||||||
for trigger in extract_ti_triggers_from_prompt(self.prompt):
|
|
||||||
name_or_key = trigger[1:-1]
|
|
||||||
try:
|
|
||||||
loaded_model = context.models.load(key=name_or_key)
|
|
||||||
model = loaded_model.model
|
|
||||||
assert isinstance(model, TextualInversionModelRaw)
|
|
||||||
ti_list.append((name_or_key, model))
|
|
||||||
except UnknownModelException:
|
|
||||||
try:
|
|
||||||
loaded_model = context.models.load_by_attrs(
|
|
||||||
model_name=name_or_key, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
|
|
||||||
)
|
|
||||||
model = loaded_model.model
|
|
||||||
assert isinstance(model, TextualInversionModelRaw)
|
|
||||||
ti_list.append((name_or_key, model))
|
|
||||||
except UnknownModelException:
|
|
||||||
logger.warning(f'trigger: "{trigger}" not found')
|
|
||||||
except ValueError:
|
|
||||||
logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
|
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
|
||||||
@ -206,26 +187,7 @@ class SDXLPromptInvocationBase:
|
|||||||
|
|
||||||
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||||
|
|
||||||
ti_list: List[Tuple[str, TextualInversionModelRaw]] = []
|
ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)
|
||||||
for trigger in extract_ti_triggers_from_prompt(prompt):
|
|
||||||
name_or_key = trigger[1:-1]
|
|
||||||
try:
|
|
||||||
loaded_model = context.models.load(key=name_or_key)
|
|
||||||
model = loaded_model.model
|
|
||||||
assert isinstance(model, TextualInversionModelRaw)
|
|
||||||
ti_list.append((name_or_key, model))
|
|
||||||
except UnknownModelException:
|
|
||||||
try:
|
|
||||||
loaded_model = context.models.load_by_attrs(
|
|
||||||
model_name=name_or_key, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
|
|
||||||
)
|
|
||||||
model = loaded_model.model
|
|
||||||
assert isinstance(model, TextualInversionModelRaw)
|
|
||||||
ti_list.append((name_or_key, model))
|
|
||||||
except UnknownModelException:
|
|
||||||
logger.warning(f'trigger: "{trigger}" not found')
|
|
||||||
except ValueError:
|
|
||||||
logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
|
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
|
||||||
|
@ -1,8 +1,44 @@
|
|||||||
import re
|
import re
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
||||||
|
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
|
from invokeai.app.services.model_records import UnknownModelException
|
||||||
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
|
|
||||||
def extract_ti_triggers_from_prompt(prompt: str) -> list[str]:
|
def extract_ti_triggers_from_prompt(prompt: str) -> List[str]:
|
||||||
ti_triggers = []
|
ti_triggers: List[str] = []
|
||||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
||||||
ti_triggers.append(trigger)
|
ti_triggers.append(str(trigger))
|
||||||
return ti_triggers
|
return ti_triggers
|
||||||
|
|
||||||
|
def generate_ti_list(prompt: str, base: BaseModelType, context: InvocationContext) -> List[Tuple[str, TextualInversionModelRaw]]:
|
||||||
|
ti_list: List[Tuple[str, TextualInversionModelRaw]] = []
|
||||||
|
for trigger in extract_ti_triggers_from_prompt(prompt):
|
||||||
|
name_or_key = trigger[1:-1]
|
||||||
|
try:
|
||||||
|
loaded_model = context.models.load(key=name_or_key)
|
||||||
|
model = loaded_model.model
|
||||||
|
assert isinstance(model, TextualInversionModelRaw)
|
||||||
|
assert loaded_model.config.base == base
|
||||||
|
ti_list.append((name_or_key, model))
|
||||||
|
except UnknownModelException:
|
||||||
|
try:
|
||||||
|
loaded_model = context.models.load_by_attrs(
|
||||||
|
model_name=name_or_key, base_model=base, model_type=ModelType.TextualInversion
|
||||||
|
)
|
||||||
|
model = loaded_model.model
|
||||||
|
assert isinstance(model, TextualInversionModelRaw)
|
||||||
|
assert loaded_model.config.base == base
|
||||||
|
ti_list.append((name_or_key, model))
|
||||||
|
except UnknownModelException:
|
||||||
|
pass
|
||||||
|
except ValueError:
|
||||||
|
logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
|
||||||
|
except AssertionError:
|
||||||
|
logger.warning(f'trigger: "{trigger}" not a valid textual inversion model for this graph')
|
||||||
|
except Exception:
|
||||||
|
logger.warning(f'Failed to load TI model for trigger: "{trigger}"')
|
||||||
|
return ti_list
|
Loading…
Reference in New Issue
Block a user