mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Ti trigger from prompt util (#5294)
* Pull logic for extracting TI triggers into a util function * Remove duplicate regex for ti triggers * Fix linting for ruff * Remove unused imports
This commit is contained in:
parent
2d11d97dad
commit
32ad742f3e
@ -1,4 +1,3 @@
|
|||||||
import re
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
@ -17,6 +16,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
|||||||
from ...backend.model_management.lora import ModelPatcher
|
from ...backend.model_management.lora import ModelPatcher
|
||||||
from ...backend.model_management.models import ModelNotFoundException, ModelType
|
from ...backend.model_management.models import ModelNotFoundException, ModelType
|
||||||
from ...backend.util.devices import torch_dtype
|
from ...backend.util.devices import torch_dtype
|
||||||
|
from ..util.ti_utils import extract_ti_triggers_from_prompt
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
@ -87,7 +87,7 @@ class CompelInvocation(BaseInvocation):
|
|||||||
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||||
|
|
||||||
ti_list = []
|
ti_list = []
|
||||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
for trigger in extract_ti_triggers_from_prompt(self.prompt):
|
||||||
name = trigger[1:-1]
|
name = trigger[1:-1]
|
||||||
try:
|
try:
|
||||||
ti_list.append(
|
ti_list.append(
|
||||||
@ -210,7 +210,7 @@ class SDXLPromptInvocationBase:
|
|||||||
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||||
|
|
||||||
ti_list = []
|
ti_list = []
|
||||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
for trigger in extract_ti_triggers_from_prompt(prompt):
|
||||||
name = trigger[1:-1]
|
name = trigger[1:-1]
|
||||||
try:
|
try:
|
||||||
ti_list.append(
|
ti_list.append(
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)
|
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import re
|
|
||||||
|
|
||||||
# from contextlib import ExitStack
|
# from contextlib import ExitStack
|
||||||
from typing import List, Literal, Union
|
from typing import List, Literal, Union
|
||||||
@ -21,6 +20,7 @@ from invokeai.backend import BaseModelType, ModelType, SubModelType
|
|||||||
from ...backend.model_management import ONNXModelPatcher
|
from ...backend.model_management import ONNXModelPatcher
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from ...backend.util import choose_torch_device
|
from ...backend.util import choose_torch_device
|
||||||
|
from ..util.ti_utils import extract_ti_triggers_from_prompt
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
@ -78,7 +78,7 @@ class ONNXPromptInvocation(BaseInvocation):
|
|||||||
]
|
]
|
||||||
|
|
||||||
ti_list = []
|
ti_list = []
|
||||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
for trigger in extract_ti_triggers_from_prompt(self.prompt):
|
||||||
name = trigger[1:-1]
|
name = trigger[1:-1]
|
||||||
try:
|
try:
|
||||||
ti_list.append(
|
ti_list.append(
|
||||||
|
8
invokeai/app/util/ti_utils.py
Normal file
8
invokeai/app/util/ti_utils.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
def extract_ti_triggers_from_prompt(prompt: str) -> list[str]:
|
||||||
|
ti_triggers = []
|
||||||
|
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
||||||
|
ti_triggers.append(trigger)
|
||||||
|
return ti_triggers
|
Loading…
Reference in New Issue
Block a user