InvokeAI/invokeai/app/invocations/strings.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

140 lines
5.5 KiB
Python
Raw Permalink Normal View History

Prompts from file support nodes (#3964) * New classes to support the PromptsFromFileInvocation Class - PromptPosNegOutput - PromptSplitNegInvocation - PromptJoinInvocation - PromptReplaceInvocation * - Added PromptsToFileInvocation, - PromptSplitNegInvocation - now counts the bracket depth so ensures it cout the numbr of open and close brackets match. - checks for escaped [ ] so ignores them if escaped e.g \[ - PromptReplaceInvocation - now has a user regex. and no regex in made caseinsesitive * Update prompt.py created class PromptsToFileInvocationOutput and use it in PromptsToFileInvocation instead of BaseInvocationOutput * Update prompt.py * Added schema_extra title and tags for PromptReplaceInvocation, PromptJoinInvocation, PromptSplitNegInvocation and PromptsToFileInvocation * Added PTFileds Collect and Expand * update to nodes v1 * added ui_type to file_path for PromptToFile * update params for the primitive types used, remove the ui_type filepath, promptsToFile now only accepts collections until a fix is available * updated the parameters for the StringOutput primitive * moved the prompt tools nodes out of the prompt.py into prompt_tools.py * more rework for v1 * added github link * updated to use "@invocation" * updated tags * Adde new nodes PromptStrength and PromptStrengthsCombine * chore: black * feat(nodes): add version to prompt nodes * renamed nodes from prompt related to string related. Also moved them into a strings.py file. Also moved and renamed the PromptsFromFileInvocation from prompt.py to strings.py. The PTfileds still remain in the Prompt_tool.py for now. 
* added , version="1.0.0" to the invocations * removed the PTField related nodes and the prompt-tools.py file all new nodes now live in the * formatted prompt.py and strings.py with Black and fixed silly mistake in the new StringSplitInvocation * - Revert Prompt.py back to original - Update strings.py to be only StringJoin, StringJoinThre, StringReplace, StringSplitNeg, StringSplit * applied isort to imports * fix(nodes): typos in `strings.py` --------- Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Co-authored-by: Millun Atluri <Millu@users.noreply.github.com>
2023-09-13 08:06:38 +00:00
# 2023 skunkworxdark (https://github.com/skunkworxdark)
import re
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InputField,
InvocationContext,
OutputField,
UIComponent,
invocation,
invocation_output,
)
from .primitives import StringOutput
@invocation_output("string_pos_neg_output")
class StringPosNegOutput(BaseInvocationOutput):
    """Base class for invocations that output a positive and negative string"""

    # Output pair returned by StringSplitNegInvocation.invoke in this file.
    # NOTE(review): the field descriptions below are runtime strings handed to
    # OutputField; presumably they are surfaced to the UI/schema - keep them stable.

    # Text that was routed to the "positive" side of the split.
    positive_string: str = OutputField(description="Positive string")
    # Text that was routed to the "negative" side of the split.
    negative_string: str = OutputField(description="Negative string")
@invocation(
    "string_split_neg",
    title="String Split Negative",
    tags=["string", "split", "negative"],
    category="string",
    version="1.0.0",
)
class StringSplitNegInvocation(BaseInvocation):
    """Splits string into two strings, inside [] goes into negative string everything else goes into positive string. Each [ and ] character is replaced with a space"""

    string: str = InputField(default="", description="String to split", ui_component=UIComponent.Textarea)

    def invoke(self, context: InvocationContext) -> StringPosNegOutput:
        """Route every character of ``self.string`` to the positive or negative output.

        Rules, applied character by character:

        * An unescaped ``[`` opens a (possibly nested) negative section and
          contributes a single space to the negative string.
        * An unescaped ``]`` closes the innermost negative section and also
          contributes a single space to the negative string, so adjacent
          sections never run together.
        * An unescaped ``]`` with no open section is unmatched: it is replaced
          by a space in the positive string and the depth is left at zero, so
          later ``[...]`` sections are still recognised.
        * A backslash escapes the following character. Escaped brackets are
          copied verbatim (together with the backslash) to whichever string is
          currently active.

        Args:
            context: Invocation context supplied by the caller; not used here.

        Returns:
            StringPosNegOutput holding the positive and negative strings.
        """
        # Collect pieces in lists and join once instead of repeated ``+=``.
        positive_parts = []
        negative_parts = []
        depth = 0  # number of currently open, unescaped "[" sections
        escaped = False  # True when the previous character was an unescaped backslash

        for char in self.string or "":
            if char == "[" and not escaped:
                negative_parts.append(" ")
                depth += 1
            elif char == "]" and not escaped:
                if depth > 0:
                    depth -= 1
                    negative_parts.append(" ")
                else:
                    # Unmatched closer: never let the depth go negative, otherwise
                    # the next "[...]" section would be treated as positive text.
                    positive_parts.append(" ")
            elif depth > 0:
                negative_parts.append(char)
            else:
                positive_parts.append(char)

            # A backslash only starts an escape when it is not itself escaped.
            escaped = char == "\\" and not escaped

        return StringPosNegOutput(
            positive_string="".join(positive_parts),
            negative_string="".join(negative_parts),
        )
@invocation_output("string_2_output")
class String2Output(BaseInvocationOutput):
    """Base class for invocations that output two strings"""

    # Output pair returned by StringSplitInvocation.invoke in this file.

    # First piece: the text before the delimiter (or the whole input when no split happened).
    string_1: str = OutputField(description="string 1")
    # Second piece: the text after the delimiter, or "" when no split happened.
    string_2: str = OutputField(description="string 2")
@invocation("string_split", title="String Split", tags=["string", "split"], category="string", version="1.0.0")
class StringSplitInvocation(BaseInvocation):
    """Splits string into two strings, based on the first occurrence of the delimiter. The delimiter will be removed from the string"""

    string: str = InputField(default="", description="String to split", ui_component=UIComponent.Textarea)
    delimiter: str = InputField(
        default="", description="Delimiter to spilt with. blank will split on the first whitespace"
    )

    def invoke(self, context: InvocationContext) -> String2Output:
        """Split ``self.string`` once and return the two halves.

        With a non-empty delimiter the split happens at its first occurrence.
        With a blank delimiter the split happens at the first run of
        whitespace, using Python's default whitespace splitting (leading
        whitespace is ignored).

        Args:
            context: Invocation context supplied by the caller; not used here.

        Returns:
            String2Output where ``string_1`` is the text before the split point
            and ``string_2`` the text after it. ``string_2`` is "" when no
            split point was found; both are "" for empty or whitespace-only
            input combined with a blank delimiter.
        """
        text = self.string or ""
        # str.split("") raises ValueError("empty separator"); passing None
        # selects whitespace splitting, which is what a blank delimiter means.
        pieces = text.split(self.delimiter or None, 1)

        # Whitespace splitting of an empty/blank string yields [], so guard both slots.
        first = pieces[0] if pieces else ""
        second = pieces[1] if len(pieces) > 1 else ""
        return String2Output(string_1=first, string_2=second)
@invocation("string_join", title="String Join", tags=["string", "join"], category="string", version="1.0.0")
class StringJoinInvocation(BaseInvocation):
    """Joins string left to string right"""

    string_left: str = InputField(default="", description="String Left", ui_component=UIComponent.Textarea)
    string_right: str = InputField(default="", description="String Right", ui_component=UIComponent.Textarea)

    def invoke(self, context: InvocationContext) -> StringOutput:
        """Concatenate the left and right inputs with nothing in between.

        A falsy input (for example ``None`` or "") contributes an empty string.

        Args:
            context: Invocation context supplied by the caller; not used here.

        Returns:
            StringOutput whose ``value`` is left followed by right.
        """
        left = self.string_left if self.string_left else ""
        right = self.string_right if self.string_right else ""
        joined = left + right
        return StringOutput(value=joined)
@invocation("string_join_three", title="String Join Three", tags=["string", "join"], category="string", version="1.0.0")
class StringJoinThreeInvocation(BaseInvocation):
    """Joins string left to string middle to string right"""

    string_left: str = InputField(default="", description="String Left", ui_component=UIComponent.Textarea)
    string_middle: str = InputField(default="", description="String Middle", ui_component=UIComponent.Textarea)
    string_right: str = InputField(default="", description="String Right", ui_component=UIComponent.Textarea)

    def invoke(self, context: InvocationContext) -> StringOutput:
        """Concatenate the left, middle and right inputs in that order.

        A falsy input (for example ``None`` or "") contributes an empty string.

        Args:
            context: Invocation context supplied by the caller; not used here.

        Returns:
            StringOutput whose ``value`` is left + middle + right.
        """
        combined = ""
        # Append in declaration order; falsy pieces are treated as "".
        for piece in (self.string_left, self.string_middle, self.string_right):
            combined = combined + (piece or "")
        return StringOutput(value=combined)
@invocation(
    "string_replace", title="String Replace", tags=["string", "replace", "regex"], category="string", version="1.0.0"
)
class StringReplaceInvocation(BaseInvocation):
    """Replaces the search string with the replace string"""

    string: str = InputField(default="", description="String to work on", ui_component=UIComponent.Textarea)
    search_string: str = InputField(default="", description="String to search for", ui_component=UIComponent.Textarea)
    replace_string: str = InputField(
        default="", description="String to replace the search", ui_component=UIComponent.Textarea
    )
    use_regex: bool = InputField(
        default=False, description="Use search string as a regex expression (non regex is case insensitive)"
    )

    def invoke(self, context: InvocationContext) -> StringOutput:
        """Replace every match of ``search_string`` in ``string``.

        * ``use_regex`` False: ``search_string`` is matched literally and
          case-insensitively, and ``replace_string`` is inserted verbatim
          (backslashes and ``\\g<...>`` sequences are NOT interpreted).
        * ``use_regex`` True: ``search_string`` is a regular expression and
          ``replace_string`` is a ``re.sub`` replacement template, so group
          references such as ``\\1`` work.

        An empty ``search_string`` leaves the input unchanged.

        Args:
            context: Invocation context supplied by the caller; not used here.

        Returns:
            StringOutput carrying the resulting string.

        Raises:
            re.error: In regex mode, when the pattern or the replacement
                template is invalid.
        """
        search = self.search_string or ""
        text = self.string or ""
        replacement = self.replace_string or ""

        if not search:
            return StringOutput(value=text)

        if self.use_regex:
            # Regex mode: the replacement is deliberately a template so that
            # back-references to captured groups keep working.
            result = re.sub(search, replacement, text)
        else:
            # Literal mode: escape the needle and use a callable replacement so
            # the replacement text is never parsed for backslash escapes
            # (a plain string such as "C:\path" would raise "bad escape").
            literal = re.compile(re.escape(search), re.IGNORECASE)
            result = literal.sub(lambda _match: replacement, text)

        return StringOutput(value=result)