From 32ad742f3e9e6aad41665213794514dd39c9a560 Mon Sep 17 00:00:00 2001 From: Brandon <58442074+brandonrising@users.noreply.github.com> Date: Thu, 21 Dec 2023 22:04:44 -0500 Subject: [PATCH] Ti trigger from prompt util (#5294) * Pull logic for extracting TI triggers into a util function * Remove duplicate regex for ti triggers * Fix linting for ruff * Remove unused imports --- invokeai/app/invocations/compel.py | 6 +++--- invokeai/app/invocations/onnx.py | 4 ++-- invokeai/app/util/ti_utils.py | 8 ++++++++ 3 files changed, 13 insertions(+), 5 deletions(-) create mode 100644 invokeai/app/util/ti_utils.py diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 5494c3261f..49c62cff56 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -1,4 +1,3 @@ -import re from dataclasses import dataclass from typing import List, Optional, Union @@ -17,6 +16,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.models import ModelNotFoundException, ModelType from ...backend.util.devices import torch_dtype +from ..util.ti_utils import extract_ti_triggers_from_prompt from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -87,7 +87,7 @@ class CompelInvocation(BaseInvocation): # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] ti_list = [] - for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): + for trigger in extract_ti_triggers_from_prompt(self.prompt): name = trigger[1:-1] try: ti_list.append( @@ -210,7 +210,7 @@ class SDXLPromptInvocationBase: # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] ti_list = [] - for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt): + for trigger in extract_ti_triggers_from_prompt(prompt): name = trigger[1:-1] try: ti_list.append( diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index 9eca5e083e..759cfde700 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -1,7 +1,6 @@ # Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779) import inspect -import re # from contextlib import ExitStack from typing import List, Literal, Union @@ -21,6 +20,7 @@ from invokeai.backend import BaseModelType, ModelType, SubModelType from ...backend.model_management import ONNXModelPatcher from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.util import choose_torch_device +from ..util.ti_utils import extract_ti_triggers_from_prompt from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -78,7 +78,7 @@ class ONNXPromptInvocation(BaseInvocation): ] ti_list = [] - for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): + for trigger in extract_ti_triggers_from_prompt(self.prompt): name = trigger[1:-1] try: ti_list.append( diff --git a/invokeai/app/util/ti_utils.py b/invokeai/app/util/ti_utils.py new file mode 100644 index 0000000000..a66a832b42 --- /dev/null +++ b/invokeai/app/util/ti_utils.py @@ -0,0 +1,8 @@ +import re + + +def extract_ti_triggers_from_prompt(prompt: str) -> list[str]: + ti_triggers = [] + for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt): + ti_triggers.append(trigger) + return ti_triggers