Ti trigger from prompt util (#5294)

* Pull logic for extracting TI triggers into a util function

* Remove duplicate regex for ti triggers

* Fix linting for ruff

* Remove unused imports
This commit is contained in:
Brandon
2023-12-21 22:04:44 -05:00
committed by GitHub
parent 2d11d97dad
commit 32ad742f3e
3 changed files with 13 additions and 5 deletions

View File

@ -1,4 +1,3 @@
import re
from dataclasses import dataclass
from typing import List, Optional, Union
@ -17,6 +16,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.models import ModelNotFoundException, ModelType
from ...backend.util.devices import torch_dtype
from ..util.ti_utils import extract_ti_triggers_from_prompt
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -87,7 +87,7 @@ class CompelInvocation(BaseInvocation):
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
for trigger in extract_ti_triggers_from_prompt(self.prompt):
name = trigger[1:-1]
try:
ti_list.append(
@ -210,7 +210,7 @@ class SDXLPromptInvocationBase:
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
for trigger in extract_ti_triggers_from_prompt(prompt):
name = trigger[1:-1]
try:
ti_list.append(