mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
c48fd9c083
Refine concept of "parameter" nodes to "primitives": - integer - float - string - boolean - image - latents - conditioning - color Each primitive has: - A field definition, if it is not already a Python primitive value. The field is how this primitive value is passed between nodes. Collections are lists of the field in node definitions. ex: `ImageField` & `list[ImageField]` - A single output class. ex: `ImageOutput` - A collection output class. ex: `ImageCollectionOutput` - A node, which functions to load or pass on the primitive value. ex: `ImageInvocation` (in this case, `ImageInvocation` replaces `LoadImage`) Plus a number of related changes: - Reorganize these into `primitives.py` - Update all nodes and logic to use primitives - Consolidate "prompt" outputs into "string" & "mask" into "image" (there's no reason for these to be different, they function identically) - Update default graphs & tests - Regen frontend types & minor frontend tidy related to changes
86 lines
3.3 KiB
Python
86 lines
3.3 KiB
Python
from os.path import exists
|
|
from typing import Literal, Optional, Union
|
|
|
|
import numpy as np
|
|
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
|
|
from pydantic import validator
|
|
|
|
from invokeai.app.invocations.primitives import StringCollectionOutput
|
|
|
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, UITypeHint, tags, title
|
|
|
|
|
|
@title("Dynamic Prompt")
@tags("prompt", "collection")
class DynamicPromptInvocation(BaseInvocation):
    """Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""

    type: Literal["dynamic_prompt"] = "dynamic_prompt"

    # Inputs
    # Template text handed to the dynamicprompts generator.
    prompt: str = InputField(description="The prompt to parse with dynamicprompts", ui_component=UIComponent.Textarea)
    # Forwarded to the generator as its prompt-count argument.
    max_prompts: int = InputField(default=1, description="The number of prompts to generate")
    # Selects CombinatorialPromptGenerator when True, RandomPromptGenerator otherwise.
    combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator")

    def invoke(self, context: InvocationContext) -> StringCollectionOutput:
        """Expand ``self.prompt`` with the selected generator and return the results.

        The two generators take the requested count under different keyword
        names (``max_prompts`` vs. ``num_images``), so each branch makes its
        own call.
        """
        if not self.combinatorial:
            expanded = RandomPromptGenerator().generate(self.prompt, num_images=self.max_prompts)
            return StringCollectionOutput(collection=expanded)

        expanded = CombinatorialPromptGenerator().generate(self.prompt, max_prompts=self.max_prompts)
        return StringCollectionOutput(collection=expanded)
|
|
|
|
|
|
@title("Prompts from File")
@tags("prompt", "file")
class PromptsFromFileInvocation(BaseInvocation):
    """Loads prompts from a text file"""

    type: Literal["prompt_from_file"] = "prompt_from_file"

    # Inputs
    file_path: str = InputField(description="Path to prompt text file", ui_type_hint=UITypeHint.FilePath)
    pre_prompt: Optional[str] = InputField(
        default=None, description="String to prepend to each prompt", ui_component=UIComponent.Textarea
    )
    post_prompt: Optional[str] = InputField(
        default=None, description="String to append to each prompt", ui_component=UIComponent.Textarea
    )
    # 1-based line number of the first line to read.
    start_line: int = InputField(default=1, ge=1, description="Line in the file to start start from")
    # 0 means "read every remaining line".
    max_prompts: int = InputField(default=1, ge=0, description="Max lines to read from file (0=all)")

    @validator("file_path")
    def file_path_exists(cls, v):
        """Reject a ``file_path`` that does not exist on disk.

        Returns:
            The unchanged path when it exists.

        Raises:
            ValueError: if nothing exists at ``v``; the message names the
                offending path so the user can see what was looked up.
        """
        if not exists(v):
            raise ValueError(f"Prompt file not found: {v}")
        return v

    def promptsFromFile(
        self,
        file_path: str,
        pre_prompt: Union[str, None],
        post_prompt: Union[str, None],
        start_line: int,
        max_prompts: int,
    ) -> list[str]:
        """Read a window of lines from ``file_path`` and wrap each one as a prompt.

        Args:
            file_path: Text file to read, one prompt per line.
            pre_prompt: Text prepended to every prompt; ``None`` or empty adds nothing.
            post_prompt: Text appended to every prompt; ``None`` or empty adds nothing.
            start_line: 1-based number of the first line to include.
            max_prompts: Maximum number of lines to include; a value ``<= 0``
                means every line from ``start_line`` to the end of the file.

        Returns:
            The selected lines, stripped of surrounding whitespace and wrapped
            in ``pre_prompt`` / ``post_prompt``. Blank lines inside the window
            are kept (as just the prefix + suffix) and count toward ``max_prompts``.
        """
        # Resolve the optional affixes once instead of on every line.
        prefix = pre_prompt or ""
        suffix = post_prompt or ""

        # enumerate() is 0-based while start_line is 1-based.
        first_index = start_line - 1
        # None = unbounded; avoids relying on a "very large int" sentinel.
        stop_index = first_index + max_prompts if max_prompts > 0 else None

        prompts = []
        # Explicit encoding so the result does not depend on the platform locale.
        with open(file_path, encoding="utf-8") as f:
            for index, line in enumerate(f):
                if stop_index is not None and index >= stop_index:
                    # Past the requested window; stop reading the file.
                    break
                if index >= first_index:
                    prompts.append(prefix + line.strip() + suffix)
        return prompts

    def invoke(self, context: InvocationContext) -> StringCollectionOutput:
        """Load the configured window of prompts and return them as a string collection."""
        prompts = self.promptsFromFile(
            self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts
        )
        return StringCollectionOutput(collection=prompts)
|