Next: Switch SDXLPromptInvocationBase to read TI names as model keys rather than model_name

This commit is contained in:
Brandon Rising 2024-02-27 10:15:39 -05:00 committed by Brandon
parent 4418c118db
commit 33856def7c
2 changed files with 8 additions and 8 deletions

View File

@ -193,11 +193,9 @@ class SDXLPromptInvocationBase:
for trigger in extract_ti_triggers_from_prompt(prompt):
name = trigger[1:-1]
try:
ti_model = context.models.load_by_attrs(
model_name=name, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
).model
assert isinstance(ti_model, TextualInversionModelRaw)
ti_list.append((name, ti_model))
loaded_model = context.models.load(key=name).model
assert isinstance(loaded_model, TextualInversionModelRaw)
ti_list.append((name, loaded_model))
except UnknownModelException:
# print(e)
# import traceback

View File

@ -1,8 +1,10 @@
import re
from typing import List
def extract_ti_triggers_from_prompt(prompt: str) -> list[str]:
ti_triggers = []
def extract_ti_triggers_from_prompt(prompt: str) -> List[str]:
ti_triggers: List[str] = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
ti_triggers.append(trigger)
ti_triggers.append(str(trigger))
return ti_triggers