mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): refactor parameter/primitive nodes
Refine concept of "parameter" nodes to "primitives": - integer - float - string - boolean - image - latents - conditioning - color Each primitive has: - A field definition, if it is not already python primitive value. The field is how this primitive value is passed between nodes. Collections are lists of the field in node definitions. ex: `ImageField` & `list[ImageField]` - A single output class. ex: `ImageOutput` - A collection output class. ex: `ImageCollectionOutput` - A node, which functions to load or pass on the primitive value. ex: `ImageInvocation` (in this case, `ImageInvocation` replaces `LoadImage`) Plus a number of related changes: - Reorganize these into `primitives.py` - Update all nodes and logic to use primitives - Consolidate "prompt" outputs into "string" & "mask" into "image" (there's no reason for these to be different, the function identically) - Update default graphs & tests - Regen frontend types & minor frontend tidy related to changes
This commit is contained in:
@ -5,7 +5,7 @@ from typing import List, Literal, Union
|
||||
import torch
|
||||
from compel import Compel, ReturnedEmbeddingsType
|
||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||
from pydantic import BaseModel, Field
|
||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
|
||||
|
||||
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import (
|
||||
BasicConditioningInfo,
|
||||
@ -32,13 +32,6 @@ from .baseinvocation import (
|
||||
from .model import ClipField
|
||||
|
||||
|
||||
class ConditioningField(BaseModel):
|
||||
conditioning_name: str = Field(description="The name of conditioning data")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["conditioning_name"]}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConditioningFieldData:
|
||||
conditionings: List[BasicConditioningInfo]
|
||||
@ -51,16 +44,6 @@ class ConditioningFieldData:
|
||||
# PerpNeg = "perp_neg"
|
||||
|
||||
|
||||
class CompelOutput(BaseInvocationOutput):
|
||||
"""Compel parser output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["compel_output"] = "compel_output"
|
||||
|
||||
conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
|
||||
# fmt: on
|
||||
|
||||
|
||||
@title("Compel Prompt")
|
||||
@tags("prompt", "compel")
|
||||
class CompelInvocation(BaseInvocation):
|
||||
@ -80,7 +63,7 @@ class CompelInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**self.clip.tokenizer.dict(),
|
||||
context=context,
|
||||
@ -163,7 +146,7 @@ class CompelInvocation(BaseInvocation):
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
context.services.latents.save(conditioning_name, conditioning_data)
|
||||
|
||||
return CompelOutput(
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
@ -303,7 +286,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
c1, c1_pooled, ec1 = self.run_clip_compel(
|
||||
context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
|
||||
)
|
||||
@ -336,7 +319,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
context.services.latents.save(conditioning_name, conditioning_data)
|
||||
|
||||
return CompelOutput(
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
@ -361,7 +344,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
# TODO: if there will appear lora for refiner - write proper prefix
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>", zero_on_empty=False)
|
||||
|
||||
@ -384,7 +367,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
context.services.latents.save(conditioning_name, conditioning_data)
|
||||
|
||||
return CompelOutput(
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
|
Reference in New Issue
Block a user