feat(nodes): refactor parameter/primitive nodes

Refine concept of "parameter" nodes to "primitives":
- integer
- float
- string
- boolean
- image
- latents
- conditioning
- color

Each primitive has:
- A field definition, if it is not already python primitive value. The field is how this primitive value is passed between nodes. Collections are lists of the field in node definitions. ex: `ImageField` & `list[ImageField]`
- A single output class. ex: `ImageOutput`
- A collection output class. ex: `ImageCollectionOutput`
- A node, which functions to load or pass on the primitive value. ex: `ImageInvocation` (in this case, `ImageInvocation` replaces `LoadImage`)

Plus a number of related changes:
- Reorganize these into `primitives.py`
- Update all nodes and logic to use primitives
- Consolidate "prompt" outputs into "string" & "mask" into "image" (there's no reason for these to be different, the function identically)
- Update default graphs & tests
- Regen frontend types & minor frontend tidy related to changes
This commit is contained in:
psychedelicious
2023-08-14 19:41:29 +10:00
parent f49fc7fb55
commit c48fd9c083
24 changed files with 887 additions and 666 deletions

View File

@ -1,40 +1,13 @@
from os.path import exists
from typing import Literal, Optional
from typing import Literal, Optional, Union
import numpy as np
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
from pydantic import validator
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InputField,
InvocationContext,
OutputField,
UIComponent,
UITypeHint,
title,
tags,
)
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
from invokeai.app.invocations.primitives import StringCollectionOutput
class PromptOutput(BaseInvocationOutput):
"""Base class for invocations that output a prompt"""
type: Literal["prompt"] = "prompt"
prompt: str = OutputField(description="The output prompt")
class PromptCollectionOutput(BaseInvocationOutput):
"""Base class for invocations that output a collection of prompts"""
type: Literal["prompt_collection_output"] = "prompt_collection_output"
prompt_collection: list[str] = OutputField(
description="The output prompt collection", ui_type_hint=UITypeHint.StringCollection
)
count: int = OutputField(description="The size of the prompt collection")
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, UITypeHint, tags, title
@title("Dynamic Prompt")
@ -49,7 +22,7 @@ class DynamicPromptInvocation(BaseInvocation):
max_prompts: int = InputField(default=1, description="The number of prompts to generate")
combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator")
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
if self.combinatorial:
generator = CombinatorialPromptGenerator()
prompts = generator.generate(self.prompt, max_prompts=self.max_prompts)
@ -57,7 +30,7 @@ class DynamicPromptInvocation(BaseInvocation):
generator = RandomPromptGenerator()
prompts = generator.generate(self.prompt, num_images=self.max_prompts)
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
return StringCollectionOutput(collection=prompts)
@title("Prompts from File")
@ -70,10 +43,10 @@ class PromptsFromFileInvocation(BaseInvocation):
# Inputs
file_path: str = InputField(description="Path to prompt text file", ui_type_hint=UITypeHint.FilePath)
pre_prompt: Optional[str] = InputField(
description="String to prepend to each prompt", ui_component=UIComponent.Textarea
default=None, description="String to prepend to each prompt", ui_component=UIComponent.Textarea
)
post_prompt: Optional[str] = InputField(
description="String to append to each prompt", ui_component=UIComponent.Textarea
default=None, description="String to append to each prompt", ui_component=UIComponent.Textarea
)
start_line: int = InputField(default=1, ge=1, description="Line in the file to start start from")
max_prompts: int = InputField(default=1, ge=0, description="Max lines to read from file (0=all)")
@ -84,7 +57,14 @@ class PromptsFromFileInvocation(BaseInvocation):
raise ValueError(FileNotFoundError)
return v
def promptsFromFile(self, file_path: str, pre_prompt: str, post_prompt: str, start_line: int, max_prompts: int):
def promptsFromFile(
self,
file_path: str,
pre_prompt: Union[str, None],
post_prompt: Union[str, None],
start_line: int,
max_prompts: int,
):
prompts = []
start_line -= 1
end_line = start_line + max_prompts
@ -98,8 +78,8 @@ class PromptsFromFileInvocation(BaseInvocation):
break
return prompts
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
prompts = self.promptsFromFile(
self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts
)
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
return StringCollectionOutput(collection=prompts)