feat: node editor

squashed rebase on main after backendd refactor
This commit is contained in:
psychedelicious 2023-08-14 13:23:09 +10:00
parent d6c9bf5b38
commit f49fc7fb55
188 changed files with 8541 additions and 4660 deletions

View File

@ -38,7 +38,7 @@ import mimetypes
from .api.dependencies import ApiDependencies
from .api.routers import sessions, models, images, boards, board_images, app_info
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
import torch
@ -134,6 +134,11 @@ def custom_openapi():
# This could break in some cases, figure out a better way to do it
output_type_titles[schema_key] = output_schema["title"]
# Add Node Editor UI helper schemas
ui_config_schemas = schema([UIConfigBase, _InputField, _OutputField], ref_prefix="#/components/schemas/")
for schema_key, output_schema in ui_config_schemas["definitions"].items():
openapi_schema["components"]["schemas"][schema_key] = output_schema
# Add a reference to the output type to additionalProperties of the invoker schema
for invoker in all_invocations:
invoker_name = invoker.__name__

View File

@ -3,15 +3,353 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum
from inspect import signature
from typing import TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args, get_type_hints
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Callable,
ClassVar,
Mapping,
Optional,
Type,
TypeVar,
Union,
get_args,
get_type_hints,
)
from pydantic import BaseConfig, BaseModel, Field
from pydantic import BaseModel, Field
from pydantic.fields import Undefined
from pydantic.typing import NoArgAnyCallable
if TYPE_CHECKING:
from ..services.invocation_services import InvocationServices
class FieldDescriptions:
denoising_start = "When to start denoising, expressed a percentage of total steps"
denoising_end = "When to stop denoising, expressed a percentage of total steps"
cfg_scale = "Classifier-Free Guidance scale"
scheduler = "Scheduler to use during inference"
positive_cond = "Positive conditioning tensor"
negative_cond = "Negative conditioning tensor"
noise = "Noise tensor"
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
unet = "UNet (scheduler, LoRAs)"
vae = "VAE"
cond = "Conditioning tensor"
controlnet_model = "ControlNet model to load"
vae_model = "VAE model to load"
lora_model = "LoRA model to load"
main_model = "Main model (UNet, VAE, CLIP) to load"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
lora_weight = "The weight at which the LoRA is applied to each model"
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
raw_prompt = "Raw prompt text (no parsing)"
sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor"
skipped_layers = "Number of layers to skip in text encoder"
seed = "Seed for random number generation"
steps = "Number of steps to run"
width = "Width of output (px)"
height = "Height of output (px)"
control = "ControlNet(s) to apply"
denoised_latents = "Denoised latents tensor"
latents = "Latents tensor"
strength = "Strength of denoising (proportional to steps)"
core_metadata = "Optional core metadata to be written to image"
interp_mode = "Interpolation mode"
torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
fp32 = "Whether or not to use full float32 precision"
precision = "Precision to use"
tiled = "Processing using overlapping tiles (reduce memory consumption)"
detect_res = "Pixel resolution for detection"
image_res = "Pixel resolution for output image"
safe_mode = "Whether or not to use safe mode"
scribble_mode = "Whether or not to use scribble mode"
scale_factor = "The factor by which to scale"
num_1 = "The first number"
num_2 = "The second number"
mask = "The mask to use for the operation"
class Input(str, Enum):
"""
The type of input a field accepts.
- `Input.Direct`: The field must have its value provided directly, when the invocation and field \
are instantiated.
- `Input.Connection`: The field must have its value provided by a connection.
- `Input.Any`: The field may have its value provided either directly or by a connection.
"""
Connection = "connection"
Direct = "direct"
Any = "any"
class UITypeHint(str, Enum):
"""
Type hints for the UI.
If a field should be provided a data type that does not exactly match the python type of the field, \
use this to provide the type that should be used instead. See the node development docs for detail \
on adding a new field type, which involves client-side changes.
"""
Integer = "integer"
Float = "float"
Boolean = "boolean"
String = "string"
Enum = "enum"
Array = "array"
ImageField = "ImageField"
LatentsField = "LatentsField"
ConditioningField = "ConditioningField"
ControlField = "ControlField"
MainModelField = "MainModelField"
SDXLMainModelField = "SDXLMainModelField"
SDXLRefinerModelField = "SDXLRefinerModelField"
ONNXModelField = "ONNXModelField"
VaeModelField = "VaeModelField"
LoRAModelField = "LoRAModelField"
ControlNetModelField = "ControlNetModelField"
UNetField = "UNetField"
VaeField = "VaeField"
ClipField = "ClipField"
ColorField = "ColorField"
ImageCollection = "ImageCollection"
IntegerCollection = "IntegerCollection"
FloatCollection = "FloatCollection"
StringCollection = "StringCollection"
BooleanCollection = "BooleanCollection"
Collection = "Collection"
CollectionItem = "CollectionItem"
Seed = "Seed"
FilePath = "FilePath"
class UIComponent(str, Enum):
"""
The type of UI component to use for a field, used to override the default components, which are \
inferred from the field type.
"""
None_ = "none"
Textarea = "textarea"
Slider = "slider"
class _InputField(BaseModel):
"""
*DO NOT USE*
This helper class is used to tell the client about our custom field attributes via OpenAPI
schema generation, and Typescript type generation from that schema. It serves no functional
purpose in the backend.
"""
input: Input
ui_hidden: bool
ui_type_hint: Optional[UITypeHint]
ui_component: Optional[UIComponent]
class _OutputField(BaseModel):
"""
*DO NOT USE*
This helper class is used to tell the client about our custom field attributes via OpenAPI
schema generation, and Typescript type generation from that schema. It serves no functional
purpose in the backend.
"""
ui_hidden: bool
ui_type_hint: Optional[UITypeHint]
def InputField(
*args: Any,
default: Any = Undefined,
default_factory: Optional[NoArgAnyCallable] = None,
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
exclude: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
include: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
const: Optional[bool] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
multiple_of: Optional[float] = None,
allow_inf_nan: Optional[bool] = None,
max_digits: Optional[int] = None,
decimal_places: Optional[int] = None,
min_items: Optional[int] = None,
max_items: Optional[int] = None,
unique_items: Optional[bool] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
allow_mutation: bool = True,
regex: Optional[str] = None,
discriminator: Optional[str] = None,
repr: bool = True,
input: Input = Input.Any,
ui_type_hint: Optional[UITypeHint] = None,
ui_component: Optional[UIComponent] = None,
ui_hidden: bool = False,
**kwargs: Any,
) -> Any:
"""
Creates an input field for an invocation.
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
that adds a few extra parameters to support graph execution and the node editor UI.
:param Input input: [Input.Any] The kind of input this field requires. \
`Input.Direct` means a value must be provided on instantiation. \
`Input.Connection` means the value must be provided by a connection. \
`Input.Any` means either will do.
:param UITypeHint ui_type_hint: [None] Optionally provides an extra type hint for the UI. \
In some situations, the field's type is not enough to infer the correct UI type. \
For example, model selection fields should render a dropdown UI component to select a model. \
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
`UITypeHint.SDXLMainModelField` to indicate that the field is an SDXL main model field.
:param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \
The UI will always render a suitable component, but sometimes you want something different than the default. \
For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \
For this case, you could provide `UIComponent.Textarea`.
: param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
"""
return Field(
*args,
default=default,
default_factory=default_factory,
alias=alias,
title=title,
description=description,
exclude=exclude,
include=include,
const=const,
gt=gt,
ge=ge,
lt=lt,
le=le,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
min_items=min_items,
max_items=max_items,
unique_items=unique_items,
min_length=min_length,
max_length=max_length,
allow_mutation=allow_mutation,
regex=regex,
discriminator=discriminator,
repr=repr,
input=input,
ui_type_hint=ui_type_hint,
ui_component=ui_component,
ui_hidden=ui_hidden,
**kwargs,
)
def OutputField(
*args: Any,
default: Any = Undefined,
default_factory: Optional[NoArgAnyCallable] = None,
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
exclude: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
include: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
const: Optional[bool] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
multiple_of: Optional[float] = None,
allow_inf_nan: Optional[bool] = None,
max_digits: Optional[int] = None,
decimal_places: Optional[int] = None,
min_items: Optional[int] = None,
max_items: Optional[int] = None,
unique_items: Optional[bool] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
allow_mutation: bool = True,
regex: Optional[str] = None,
discriminator: Optional[str] = None,
repr: bool = True,
ui_type_hint: Optional[UITypeHint] = None,
ui_hidden: bool = False,
**kwargs: Any,
) -> Any:
"""
Creates an output field for an invocation output.
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
that adds a few extra parameters to support graph execution and the node editor UI.
:param UITypeHint ui_type_hint: [None] Optionally provides an extra type hint for the UI. \
In some situations, the field's type is not enough to infer the correct UI type. \
For example, model selection fields should render a dropdown UI component to select a model. \
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
`UITypeHint.SDXLMainModelField` to indicate that the field is an SDXL main model field.
: param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
"""
return Field(
*args,
default=default,
default_factory=default_factory,
alias=alias,
title=title,
description=description,
exclude=exclude,
include=include,
const=const,
gt=gt,
ge=ge,
lt=lt,
le=le,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
min_items=min_items,
max_items=max_items,
unique_items=unique_items,
min_length=min_length,
max_length=max_length,
allow_mutation=allow_mutation,
regex=regex,
discriminator=discriminator,
repr=repr,
ui_type_hint=ui_type_hint,
ui_hidden=ui_hidden,
**kwargs,
)
class UIConfigBase(BaseModel):
"""
Provides additional node configuration to the UI.
This is used internally by the @tags and @title decorator logic. You probably want to use those
decorators, though you may add this class to a node definition to specify the title and tags.
"""
tags: Optional[list[str]] = Field(default_factory=None, description="The tags to display in the UI")
title: Optional[str] = Field(default=None, description="The display name of the node")
class InvocationContext:
services: InvocationServices
graph_execution_state_id: str
@ -39,6 +377,20 @@ class BaseInvocationOutput(BaseModel):
return tuple(subclasses)
class RequiredConnectionException(Exception):
"""Raised when an field which requires a connection did not receive a value."""
def __init__(self, node_id: str, field_name: str):
super().__init__(f"Node {node_id} missing connections for field {field_name}")
class MissingInputException(Exception):
"""Raised when an field which requires some input, but did not receive a value."""
def __init__(self, node_id: str, field_name: str):
super().__init__(f"Node {node_id} missing value or connection for field {field_name}")
class BaseInvocation(ABC, BaseModel):
"""A node to process inputs and produce outputs.
May use dependency injection in __init__ to receive providers.
@ -76,70 +428,81 @@ class BaseInvocation(ABC, BaseModel):
def get_output_type(cls):
return signature(cls.invoke).return_annotation
class Config:
@staticmethod
def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
uiconfig = getattr(model_class, "UIConfig", None)
if uiconfig and hasattr(uiconfig, "title"):
schema["title"] = uiconfig.title
if uiconfig and hasattr(uiconfig, "tags"):
schema["tags"] = uiconfig.tags
@abstractmethod
def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
"""Invoke with provided context and return outputs."""
pass
# fmt: off
id: str = Field(description="The id of this node. Must be unique among all nodes.")
is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
# fmt: on
def __init__(self, **data):
# nodes may have required fields, that can accept input from connections
# on instantiation of the model, we need to exclude these from validation
restore = dict()
try:
field_names = list(self.__fields__.keys())
for field_name in field_names:
# if the field is required and may get its value from a connection, exclude it from validation
field = self.__fields__[field_name]
_input = field.field_info.extra.get("input", None)
if _input in [Input.Connection, Input.Any] and field.required:
if field_name not in data:
restore[field_name] = self.__fields__.pop(field_name)
# instantiate the node, which will validate the data
super().__init__(**data)
finally:
# restore the removed fields
for field_name, field in restore.items():
self.__fields__[field_name] = field
def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
for field_name, field in self.__fields__.items():
_input = field.field_info.extra.get("input", None)
if field.required and not hasattr(self, field_name):
if _input == Input.Connection:
raise RequiredConnectionException(self.__fields__["type"].default, field_name)
elif _input == Input.Any:
raise MissingInputException(self.__fields__["type"].default, field_name)
return self.invoke(context)
id: str = InputField(description="The id of this node. Must be unique among all nodes.")
is_intermediate: bool = InputField(
default=False, description="Whether or not this node is an intermediate node.", input=Input.Direct
)
UIConfig: ClassVar[Type[UIConfigBase]]
# TODO: figure out a better way to provide these hints
# TODO: when we can upgrade to python 3.11, we can use the`NotRequired` type instead of `total=False`
class UIConfig(TypedDict, total=False):
type_hints: Dict[
str,
Literal[
"integer",
"float",
"boolean",
"string",
"enum",
"image",
"latents",
"model",
"control",
"image_collection",
"vae_model",
"lora_model",
],
]
tags: List[str]
title: str
T = TypeVar("T", bound=BaseInvocation)
class CustomisedSchemaExtra(TypedDict):
ui: UIConfig
def title(title: str) -> Callable[[Type[T]], Type[T]]:
"""Adds a title to the invocation. Use this to override the default title generation, which is based on the class name."""
def wrapper(cls: Type[T]) -> Type[T]:
uiconf_name = cls.__qualname__ + ".UIConfig"
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
cls.UIConfig.title = title
return cls
return wrapper
class InvocationConfig(BaseConfig):
"""Customizes pydantic's BaseModel.Config class for use by Invocations.
def tags(*tags: str) -> Callable[[Type[T]], Type[T]]:
"""Adds tags to the invocation. Use this to improve the streamline finding the invocation in the UI."""
Provide `schema_extra` a `ui` dict to add hints for generated UIs.
def wrapper(cls: Type[T]) -> Type[T]:
uiconf_name = cls.__qualname__ + ".UIConfig"
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
cls.UIConfig.tags = list(tags)
return cls
`tags`
- A list of strings, used to categorise invocations.
`type_hints`
- A dict of field types which override the types in the invocation definition.
- Each key should be the name of one of the invocation's fields.
- Each value should be one of the valid types:
- `integer`, `float`, `boolean`, `string`, `enum`, `image`, `latents`, `model`
```python
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["stable-diffusion", "image"],
"type_hints": {
"initial_image": "image",
},
},
}
```
"""
schema_extra: CustomisedSchemaExtra
return wrapper

View File

@ -3,58 +3,78 @@
from typing import Literal
import numpy as np
from pydantic import Field, validator
from pydantic import validator
from invokeai.app.models.image import ImageField
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext, UIConfig
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InputField,
InvocationContext,
OutputField,
UITypeHint,
tags,
title,
)
class IntCollectionOutput(BaseInvocationOutput):
"""A collection of integers"""
type: Literal["int_collection"] = "int_collection"
type: Literal["int_collection_output"] = "int_collection_output"
# Outputs
collection: list[int] = Field(default=[], description="The int collection")
collection: list[int] = OutputField(
default=[], description="The int collection", ui_type_hint=UITypeHint.IntegerCollection
)
class FloatCollectionOutput(BaseInvocationOutput):
"""A collection of floats"""
type: Literal["float_collection"] = "float_collection"
type: Literal["float_collection_output"] = "float_collection_output"
# Outputs
collection: list[float] = Field(default=[], description="The float collection")
collection: list[float] = OutputField(
default=[], description="The float collection", ui_type_hint=UITypeHint.FloatCollection
)
class StringCollectionOutput(BaseInvocationOutput):
"""A collection of strings"""
type: Literal["string_collection_output"] = "string_collection_output"
# Outputs
collection: list[str] = OutputField(
default=[], description="The output strings", ui_type_hint=UITypeHint.StringCollection
)
class ImageCollectionOutput(BaseInvocationOutput):
"""A collection of images"""
type: Literal["image_collection"] = "image_collection"
type: Literal["image_collection_output"] = "image_collection_output"
# Outputs
collection: list[ImageField] = Field(default=[], description="The output images")
class Config:
schema_extra = {"required": ["type", "collection"]}
collection: list[ImageField] = OutputField(
default=[], description="The output images", ui_type_hint=UITypeHint.ImageCollection
)
@title("Integer Range")
@tags("collection", "integer", "range")
class RangeInvocation(BaseInvocation):
"""Creates a range of numbers from start to stop with step"""
type: Literal["range"] = "range"
# Inputs
start: int = Field(default=0, description="The start of the range")
stop: int = Field(default=10, description="The stop of the range")
step: int = Field(default=1, description="The step of the range")
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Range", "tags": ["range", "integer", "collection"]},
}
start: int = InputField(default=0, description="The start of the range")
stop: int = InputField(default=10, description="The stop of the range")
step: int = InputField(default=1, description="The step of the range")
@validator("stop")
def stop_gt_start(cls, v, values):
@ -66,72 +86,56 @@ class RangeInvocation(BaseInvocation):
return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
@title("Integer Range of Size")
@tags("range", "integer", "size", "collection")
class RangeOfSizeInvocation(BaseInvocation):
"""Creates a range from start to start + size with step"""
type: Literal["range_of_size"] = "range_of_size"
# Inputs
start: int = Field(default=0, description="The start of the range")
size: int = Field(default=1, description="The number of values")
step: int = Field(default=1, description="The step of the range")
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Sized Range", "tags": ["range", "integer", "size", "collection"]},
}
start: int = InputField(default=0, description="The start of the range")
size: int = InputField(default=1, description="The number of values")
step: int = InputField(default=1, description="The step of the range")
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))
@title("Random Range")
@tags("range", "integer", "random", "collection")
class RandomRangeInvocation(BaseInvocation):
"""Creates a collection of random numbers"""
type: Literal["random_range"] = "random_range"
# Inputs
low: int = Field(default=0, description="The inclusive low value")
high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value")
size: int = Field(default=1, description="The number of values to generate")
seed: int = Field(
low: int = InputField(default=0, description="The inclusive low value")
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
size: int = InputField(default=1, description="The number of values to generate")
seed: int = InputField(
ge=0,
le=SEED_MAX,
description="The seed for the RNG (omit for random)",
default_factory=get_random_seed,
)
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Random Range", "tags": ["range", "integer", "random", "collection"]},
}
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
rng = np.random.default_rng(self.seed)
return IntCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size)))
@title("Image Collection")
@tags("image", "collection")
class ImageCollectionInvocation(BaseInvocation):
"""Load a collection of images and provide it as output."""
# fmt: off
type: Literal["image_collection"] = "image_collection"
# Inputs
images: list[ImageField] = Field(
default=[], description="The image collection to load"
images: list[ImageField] = InputField(
default=[], description="The image collection to load", ui_type_hint=UITypeHint.ImageCollection
)
# fmt: on
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
return ImageCollectionOutput(collection=self.images)
class Config(InvocationConfig):
schema_extra = {
"ui": {
"type_hints": {
"title": "Image Collection",
"images": "image_collection",
}
},
}

View File

@ -1,29 +1,39 @@
from typing import Literal, Optional, Union, List, Annotated
from pydantic import BaseModel, Field
import re
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from .model import ClipField
from ...backend.util.devices import torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.model_management import BaseModelType, ModelType, SubModelType, ModelPatcher
from dataclasses import dataclass
from typing import List, Literal, Union
import torch
from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from ...backend.util.devices import torch_dtype
from ...backend.model_management import ModelType
from ...backend.model_management.models import ModelNotFoundException
from pydantic import BaseModel, Field
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import (
BasicConditioningInfo,
SDXLConditioningInfo,
)
from ...backend.model_management import ModelPatcher, ModelType
from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import InvokeAIDiffuserComponent, BasicConditioningInfo, SDXLConditioningInfo
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from ...backend.model_management.models import ModelNotFoundException
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.util.devices import torch_dtype
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
OutputField,
UIComponent,
tags,
title,
)
from .model import ClipField
from dataclasses import dataclass
class ConditioningField(BaseModel):
conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")
conditioning_name: str = Field(description="The name of conditioning data")
class Config:
schema_extra = {"required": ["conditioning_name"]}
@ -47,23 +57,27 @@ class CompelOutput(BaseInvocationOutput):
# fmt: off
type: Literal["compel_output"] = "compel_output"
conditioning: ConditioningField = Field(default=None, description="Conditioning")
conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
# fmt: on
@title("Compel Prompt")
@tags("prompt", "compel")
class CompelInvocation(BaseInvocation):
"""Parse prompt using compel package to conditioning."""
type: Literal["compel"] = "compel"
prompt: str = Field(default="", description="Prompt")
clip: ClipField = Field(None, description="Clip to use")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
}
prompt: str = InputField(
default="",
description=FieldDescriptions.compel_prompt,
ui_component=UIComponent.Textarea,
)
clip: ClipField = InputField(
title="CLIP",
description=FieldDescriptions.clip,
input=Input.Connection,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
@ -270,27 +284,23 @@ class SDXLPromptInvocationBase:
return c, c_pooled, ec
@title("SDXL Compel Prompt")
@tags("sdxl", "compel", "prompt")
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
type: Literal["sdxl_compel_prompt"] = "sdxl_compel_prompt"
prompt: str = Field(default="", description="Prompt")
style: str = Field(default="", description="Style prompt")
original_width: int = Field(1024, description="")
original_height: int = Field(1024, description="")
crop_top: int = Field(0, description="")
crop_left: int = Field(0, description="")
target_width: int = Field(1024, description="")
target_height: int = Field(1024, description="")
clip: ClipField = Field(None, description="Clip to use")
clip2: ClipField = Field(None, description="Clip2 to use")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "SDXL Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
}
prompt: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
style: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
original_width: int = InputField(default=1024, description="")
original_height: int = InputField(default=1024, description="")
crop_top: int = InputField(default=0, description="")
crop_left: int = InputField(default=0, description="")
target_width: int = InputField(default=1024, description="")
target_height: int = InputField(default=1024, description="")
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
@ -333,28 +343,22 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
)
@title("SDXL Refiner Compel Prompt")
@tags("sdxl", "compel", "prompt")
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt"
style: str = Field(default="", description="Style prompt") # TODO: ?
original_width: int = Field(1024, description="")
original_height: int = Field(1024, description="")
crop_top: int = Field(0, description="")
crop_left: int = Field(0, description="")
aesthetic_score: float = Field(6.0, description="")
clip2: ClipField = Field(None, description="Clip to use")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Refiner Prompt (Compel)",
"tags": ["prompt", "compel"],
"type_hints": {"model": "model"},
},
}
style: str = InputField(
default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea
) # TODO: ?
original_width: int = InputField(default=1024, description="")
original_height: int = InputField(default=1024, description="")
crop_top: int = InputField(default=0, description="")
crop_left: int = InputField(default=0, description="")
aesthetic_score: float = InputField(default=6.0, description=FieldDescriptions.sdxl_aesthetic)
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
@ -391,21 +395,18 @@ class ClipSkipInvocationOutput(BaseInvocationOutput):
"""Clip skip node output"""
type: Literal["clip_skip_output"] = "clip_skip_output"
clip: ClipField = Field(None, description="Clip with skipped layers")
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@title("CLIP Skip")
@tags("clipskip", "clip", "skip")
class ClipSkipInvocation(BaseInvocation):
"""Skip layers in clip text_encoder model."""
type: Literal["clip_skip"] = "clip_skip"
clip: ClipField = Field(None, description="Clip to use")
skipped_layers: int = Field(0, description="Number of layers to skip in text_encoder")
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "CLIP Skip", "tags": ["clip", "skip"]},
}
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
self.clip.skipped_layers += self.skipped_layers

View File

@ -28,77 +28,27 @@ from pydantic import BaseModel, Field, validator
from ...backend.model_management import BaseModelType, ModelType
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from ..models.image import ImageOutput, PILInvocationConfig
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
InputField,
Input,
InvocationContext,
OutputField,
UITypeHint,
tags,
title,
)
from ..models.image import ImageOutput
CONTROLNET_DEFAULT_MODELS = [
###########################################
# lllyasviel sd v1.5, ControlNet v1.0 models
##############################################
"lllyasviel/sd-controlnet-canny",
"lllyasviel/sd-controlnet-depth",
"lllyasviel/sd-controlnet-hed",
"lllyasviel/sd-controlnet-seg",
"lllyasviel/sd-controlnet-openpose",
"lllyasviel/sd-controlnet-scribble",
"lllyasviel/sd-controlnet-normal",
"lllyasviel/sd-controlnet-mlsd",
#############################################
# lllyasviel sd v1.5, ControlNet v1.1 models
#############################################
"lllyasviel/control_v11p_sd15_canny",
"lllyasviel/control_v11p_sd15_openpose",
"lllyasviel/control_v11p_sd15_seg",
# "lllyasviel/control_v11p_sd15_depth", # broken
"lllyasviel/control_v11f1p_sd15_depth",
"lllyasviel/control_v11p_sd15_normalbae",
"lllyasviel/control_v11p_sd15_scribble",
"lllyasviel/control_v11p_sd15_mlsd",
"lllyasviel/control_v11p_sd15_softedge",
"lllyasviel/control_v11p_sd15s2_lineart_anime",
"lllyasviel/control_v11p_sd15_lineart",
"lllyasviel/control_v11p_sd15_inpaint",
# "lllyasviel/control_v11u_sd15_tile",
# problem (temporary?) with huffingface "lllyasviel/control_v11u_sd15_tile",
# so for now replace "lllyasviel/control_v11f1e_sd15_tile",
"lllyasviel/control_v11e_sd15_shuffle",
"lllyasviel/control_v11e_sd15_ip2p",
"lllyasviel/control_v11f1e_sd15_tile",
#################################################
# thibaud sd v2.1 models (ControlNet v1.0? or v1.1?
##################################################
"thibaud/controlnet-sd21-openpose-diffusers",
"thibaud/controlnet-sd21-canny-diffusers",
"thibaud/controlnet-sd21-depth-diffusers",
"thibaud/controlnet-sd21-scribble-diffusers",
"thibaud/controlnet-sd21-hed-diffusers",
"thibaud/controlnet-sd21-zoedepth-diffusers",
"thibaud/controlnet-sd21-color-diffusers",
"thibaud/controlnet-sd21-openposev2-diffusers",
"thibaud/controlnet-sd21-lineart-diffusers",
"thibaud/controlnet-sd21-normalbae-diffusers",
"thibaud/controlnet-sd21-ade20k-diffusers",
##############################################
# ControlNetMediaPipeface, ControlNet v1.1
##############################################
# ["CrucibleAI/ControlNetMediaPipeFace", "diffusion_sd15"], # SD 1.5
# diffusion_sd15 needs to be passed to from_pretrained() as subfolder arg
# hacked t2l to split to model & subfolder if format is "model,subfolder"
"CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15", # SD 1.5
"CrucibleAI/ControlNetMediaPipeFace", # SD 2.1?
]
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]
CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]
CONTROLNET_RESIZE_VALUES = Literal[
tuple(
[
"just_resize",
"crop_resize",
"fill_resize",
"just_resize_simple",
]
)
"just_resize",
"crop_resize",
"fill_resize",
"just_resize_simple",
]
@ -110,9 +60,8 @@ class ControlNetModelField(BaseModel):
class ControlField(BaseModel):
image: ImageField = Field(default=None, description="The control image")
control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
image: ImageField = Field(description="The control image")
control_model: ControlNetModelField = Field(description="The ControlNet model to use")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
@ -135,60 +84,39 @@ class ControlField(BaseModel):
raise ValueError("Control weights must be within -1 to 2 range")
return v
class Config:
schema_extra = {
"required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"],
"ui": {
"type_hints": {
"control_weight": "float",
"control_model": "controlnet_model",
# "control_weight": "number",
}
},
}
class ControlOutput(BaseInvocationOutput):
"""node output for ControlNet info"""
# fmt: off
type: Literal["control_output"] = "control_output"
control: ControlField = Field(default=None, description="The control info")
# fmt: on
# Outputs
control: ControlField = OutputField(description=FieldDescriptions.control)
@title("ControlNet")
@tags("controlnet")
class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
# fmt: off
type: Literal["controlnet"] = "controlnet"
# Inputs
image: ImageField = Field(default=None, description="The control image")
control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny",
description="control model used")
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
begin_step_percent: float = Field(default=0, ge=-1, le=2,
description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)")
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used")
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode used")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "ControlNet",
"tags": ["controlnet", "latents"],
"type_hints": {
"model": "model",
"control": "control",
# "cfg_scale": "float",
"cfg_scale": "number",
"control_weight": "float",
},
},
}
# Inputs
image: ImageField = InputField(description="The control image")
control_model: ControlNetModelField = InputField(
default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
)
control_weight: Union[float, List[float]] = InputField(
default=1.0, description="The weight given to the ControlNet", ui_type_hint=UITypeHint.Float
)
begin_step_percent: float = InputField(
default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"
)
end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
def invoke(self, context: InvocationContext) -> ControlOutput:
return ControlOutput(
@ -204,19 +132,13 @@ class ControlNetInvocation(BaseInvocation):
)
class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
class ImageProcessorInvocation(BaseInvocation):
"""Base class for invocations that preprocess images for ControlNet"""
# fmt: off
type: Literal["image_processor"] = "image_processor"
# Inputs
image: ImageField = Field(default=None, description="The image to process")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Image Processor", "tags": ["image", "processor"]},
}
# Inputs
image: ImageField = InputField(description="The image to process")
def run_processor(self, image):
# superclass just passes through image without processing
@ -255,20 +177,20 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
)
class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
@title("Canny Processor")
@tags("controlnet", "canny")
class CannyImageProcessorInvocation(ImageProcessorInvocation):
"""Canny edge detection for ControlNet"""
# fmt: off
type: Literal["canny_image_processor"] = "canny_image_processor"
# Input
low_threshold: int = Field(default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)")
high_threshold: int = Field(default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Canny Processor", "tags": ["controlnet", "canny", "image", "processor"]},
}
# Input
low_threshold: int = InputField(
default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
)
high_threshold: int = InputField(
default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)"
)
def run_processor(self, image):
canny_processor = CannyDetector()
@ -276,23 +198,19 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
return processed_image
class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
@title("HED (softedge) Processor")
@tags("controlnet", "hed", "softedge")
class HedImageProcessorInvocation(ImageProcessorInvocation):
"""Applies HED edge detection to image"""
# fmt: off
type: Literal["hed_image_processor"] = "hed_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
# safe not supported in controlnet_aux v0.0.3
# safe: bool = Field(default=False, description="whether to use safe mode")
scribble: bool = Field(default=False, description="Whether to use scribble mode")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Softedge(HED) Processor", "tags": ["controlnet", "softedge", "hed", "image", "processor"]},
}
# Inputs
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
# safe not supported in controlnet_aux v0.0.3
# safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
def run_processor(self, image):
hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
@ -307,21 +225,17 @@ class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig)
return processed_image
class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
@title("Lineart Processor")
@tags("controlnet", "lineart")
class LineartImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art processing to image"""
# fmt: off
type: Literal["lineart_image_processor"] = "lineart_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
coarse: bool = Field(default=False, description="Whether to use coarse mode")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Lineart Processor", "tags": ["controlnet", "lineart", "image", "processor"]},
}
# Inputs
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
coarse: bool = InputField(default=False, description="Whether to use coarse mode")
def run_processor(self, image):
lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
@ -331,23 +245,16 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCon
return processed_image
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
@title("Lineart Anime Processor")
@tags("controlnet", "lineart", "anime")
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art anime processing to image"""
# fmt: off
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Lineart Anime Processor",
"tags": ["controlnet", "lineart", "anime", "image", "processor"],
},
}
# Inputs
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image):
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
@ -359,21 +266,17 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocati
return processed_image
class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
@title("Openpose Processor")
@tags("controlnet", "openpose", "pose")
class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Openpose processing to image"""
# fmt: off
type: Literal["openpose_image_processor"] = "openpose_image_processor"
# Inputs
hand_and_face: bool = Field(default=False, description="Whether to use hands and face mode")
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Openpose Processor", "tags": ["controlnet", "openpose", "image", "processor"]},
}
# Inputs
hand_and_face: bool = InputField(default=False, description="Whether to use hands and face mode")
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image):
openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
@ -386,22 +289,18 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
return processed_image
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
@title("Midas (Depth) Processor")
@tags("controlnet", "midas", "depth")
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Midas depth processing to image"""
# fmt: off
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
# Inputs
a_mult: float = Field(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
bg_th: float = Field(default=0.1, ge=0, description="Midas parameter `bg_th`")
# depth_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Midas (Depth) Processor", "tags": ["controlnet", "midas", "depth", "image", "processor"]},
}
# Inputs
a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`")
# depth_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
def run_processor(self, image):
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
@ -415,20 +314,16 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocation
return processed_image
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
@title("Normal BAE Processor")
@tags("controlnet", "normal", "bae")
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies NormalBae processing to image"""
# fmt: off
type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Normal BAE Processor", "tags": ["controlnet", "normal", "bae", "image", "processor"]},
}
# Inputs
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image):
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
@ -438,22 +333,18 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationC
return processed_image
class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
@title("MLSD Processor")
@tags("controlnet", "mlsd")
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
"""Applies MLSD processing to image"""
# fmt: off
type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_v`")
thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_d`")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "MLSD Processor", "tags": ["controlnet", "mlsd", "image", "processor"]},
}
# Inputs
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
def run_processor(self, image):
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
@ -467,22 +358,18 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
return processed_image
class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
@title("PIDI Processor")
@tags("controlnet", "pidi")
class PidiImageProcessorInvocation(ImageProcessorInvocation):
"""Applies PIDI processing to image"""
# fmt: off
type: Literal["pidi_image_processor"] = "pidi_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
safe: bool = Field(default=False, description="Whether to use safe mode")
scribble: bool = Field(default=False, description="Whether to use scribble mode")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "PIDI Processor", "tags": ["controlnet", "pidi", "image", "processor"]},
}
# Inputs
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
def run_processor(self, image):
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
@ -496,26 +383,19 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
return processed_image
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
@title("Content Shuffle Processor")
@tags("controlnet", "contentshuffle")
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
"""Applies content shuffle processing to image"""
# fmt: off
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
h: Optional[int] = Field(default=512, ge=0, description="Content shuffle `h` parameter")
w: Optional[int] = Field(default=512, ge=0, description="Content shuffle `w` parameter")
f: Optional[int] = Field(default=256, ge=0, description="Content shuffle `f` parameter")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Content Shuffle Processor",
"tags": ["controlnet", "contentshuffle", "image", "processor"],
},
}
# Inputs
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
h: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
w: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
f: Optional[int] = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
def run_processor(self, image):
content_shuffle_processor = ContentShuffleDetector()
@ -531,17 +411,12 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvoca
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
@title("Zoe (Depth) Processor")
@tags("controlnet", "zoe", "depth")
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Zoe depth processing to image"""
# fmt: off
type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Zoe (Depth) Processor", "tags": ["controlnet", "zoe", "depth", "image", "processor"]},
}
def run_processor(self, image):
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
@ -549,20 +424,16 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
return processed_image
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
@title("Mediapipe Face Processor")
@tags("controlnet", "mediapipe", "face")
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
"""Applies mediapipe face processing to image"""
# fmt: off
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
# Inputs
max_faces: int = Field(default=1, ge=1, description="Maximum number of faces to detect")
min_confidence: float = Field(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Mediapipe Processor", "tags": ["controlnet", "mediapipe", "image", "processor"]},
}
# Inputs
max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
def run_processor(self, image):
# MediaPipeFaceDetector throws an error if image has alpha channel
@ -574,23 +445,19 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
return processed_image
class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
@title("Leres (Depth) Processor")
@tags("controlnet", "leres", "depth")
class LeresImageProcessorInvocation(ImageProcessorInvocation):
"""Applies leres processing to image"""
# fmt: off
type: Literal["leres_image_processor"] = "leres_image_processor"
# Inputs
thr_a: float = Field(default=0, description="Leres parameter `thr_a`")
thr_b: float = Field(default=0, description="Leres parameter `thr_b`")
boost: bool = Field(default=False, description="Whether to use boost mode")
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Leres (Depth) Processor", "tags": ["controlnet", "leres", "depth", "image", "processor"]},
}
# Inputs
thr_a: float = InputField(default=0, description="Leres parameter `thr_a`")
thr_b: float = InputField(default=0, description="Leres parameter `thr_b`")
boost: bool = InputField(default=False, description="Whether to use boost mode")
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image):
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
@ -605,21 +472,16 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
return processed_image
class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
# fmt: off
type: Literal["tile_image_processor"] = "tile_image_processor"
# Inputs
#res: int = Field(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
down_sampling_rate: float = Field(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
# fmt: on
@title("Tile Resample Processor")
@tags("controlnet", "tile")
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
"""Tile resampler processor"""
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Tile Resample Processor",
"tags": ["controlnet", "tile", "resample", "image", "processor"],
},
}
type: Literal["tile_image_processor"] = "tile_image_processor"
# Inputs
# res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
# tile_resample copied from sd-webui-controlnet/scripts/processor.py
def tile_resample(
@ -648,20 +510,12 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
return processed_image
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
@title("Segment Anything Processor")
@tags("controlnet", "segmentanything")
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
"""Applies segment anything processing to image"""
# fmt: off
type: Literal["segment_anything_processor"] = "segment_anything_processor"
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Segment Anything Processor",
"tags": ["controlnet", "segment", "anything", "sam", "image", "processor"],
},
}
def run_processor(self, image):
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")

View File

@ -5,40 +5,22 @@ from typing import Literal
import cv2 as cv
import numpy
from PIL import Image, ImageOps
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
from .image import ImageOutput
class CvInvocationConfig(BaseModel):
"""Helper class to provide all OpenCV invocations with additional config"""
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["cv", "image"],
},
}
class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
@title("OpenCV Inpaint")
@tags("opencv", "inpaint")
class CvInpaintInvocation(BaseInvocation):
"""Simple inpaint using opencv."""
# fmt: off
type: Literal["cv_inpaint"] = "cv_inpaint"
# Inputs
image: ImageField = Field(default=None, description="The image to inpaint")
mask: ImageField = Field(default=None, description="The mask to use when inpainting")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "OpenCV Inpaint", "tags": ["opencv", "inpaint"]},
}
image: ImageField = InputField(description="The image to inpaint")
mask: ImageField = InputField(description="The mask to use when inpainting")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)

View File

@ -1,37 +1,30 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from pathlib import Path
from typing import Literal, Optional, Union
from typing import Literal, Optional
import cv2
import numpy
from PIL import Image, ImageChops, ImageFilter, ImageOps
from pydantic import Field
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker
from ..models.image import ImageCategory, ImageField, ImageOutput, MaskOutput, PILInvocationConfig, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
from ..models.image import ImageCategory, ImageField, ImageOutput, MaskOutput, ResourceOrigin
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, tags, title
@title("Load Image")
@tags("image")
class LoadImageInvocation(BaseInvocation):
"""Load an image and provide it as output."""
# fmt: off
# Metadata
type: Literal["load_image"] = "load_image"
# Inputs
image: Optional[ImageField] = Field(
default=None, description="The image to load"
)
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Load Image", "tags": ["image", "load"]},
}
image: ImageField = InputField(description="The image to load")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@ -43,18 +36,16 @@ class LoadImageInvocation(BaseInvocation):
)
@title("Show Image")
@tags("image")
class ShowImageInvocation(BaseInvocation):
"""Displays a provided image, and passes it forward in the pipeline."""
# Metadata
type: Literal["show_image"] = "show_image"
# Inputs
image: Optional[ImageField] = Field(default=None, description="The image to show")
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Show Image", "tags": ["image", "show"]},
}
image: ImageField = InputField(description="The image to show")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@ -70,24 +61,20 @@ class ShowImageInvocation(BaseInvocation):
)
class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
@title("Crop Image")
@tags("image", "crop")
class ImageCropInvocation(BaseInvocation):
"""Crops an image to a specified box. The box can be outside of the image."""
# fmt: off
# Metadata
type: Literal["img_crop"] = "img_crop"
# Inputs
image: Optional[ImageField] = Field(default=None, description="The image to crop")
x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Crop Image", "tags": ["image", "crop"]},
}
image: ImageField = InputField(description="The image to crop")
x: int = InputField(default=0, description="The left x coordinate of the crop rectangle")
y: int = InputField(default=0, description="The top y coordinate of the crop rectangle")
width: int = InputField(default=512, gt=0, description="The width of the crop rectangle")
height: int = InputField(default=512, gt=0, description="The height of the crop rectangle")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@ -111,24 +98,23 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
)
class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
@title("Paste Image")
@tags("image", "paste")
class ImagePasteInvocation(BaseInvocation):
"""Pastes an image into another image."""
# fmt: off
# Metadata
type: Literal["img_paste"] = "img_paste"
# Inputs
base_image: Optional[ImageField] = Field(default=None, description="The base image")
image: Optional[ImageField] = Field(default=None, description="The image to paste")
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Paste Image", "tags": ["image", "paste"]},
}
base_image: ImageField = InputField(description="The base image")
image: ImageField = InputField(description="The image to paste")
mask: Optional[ImageField] = InputField(
default=None,
description="The mask to use when pasting",
)
x: int = InputField(default=0, description="The left x coordinate at which to paste the image")
y: int = InputField(default=0, description="The top y coordinate at which to paste the image")
def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get_pil_image(self.base_image.image_name)
@ -164,21 +150,17 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
)
class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
@title("Mask from Alpha")
@tags("image", "mask")
class MaskFromAlphaInvocation(BaseInvocation):
"""Extracts the alpha channel of an image as a mask."""
# fmt: off
# Metadata
type: Literal["tomask"] = "tomask"
# Inputs
image: Optional[ImageField] = Field(default=None, description="The image to create the mask from")
invert: bool = Field(default=False, description="Whether or not to invert the mask")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Mask From Alpha", "tags": ["image", "mask", "alpha"]},
}
image: ImageField = InputField(description="The image to create the mask from")
invert: bool = InputField(default=False, description="Whether or not to invert the mask")
def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@ -203,21 +185,17 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
)
class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
@title("Multiply Images")
@tags("image", "multiply")
class ImageMultiplyInvocation(BaseInvocation):
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
# fmt: off
# Metadata
type: Literal["img_mul"] = "img_mul"
# Inputs
image1: Optional[ImageField] = Field(default=None, description="The first image to multiply")
image2: Optional[ImageField] = Field(default=None, description="The second image to multiply")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Multiply Images", "tags": ["image", "multiply"]},
}
image1: ImageField = InputField(description="The first image to multiply")
image2: ImageField = InputField(description="The second image to multiply")
def invoke(self, context: InvocationContext) -> ImageOutput:
image1 = context.services.images.get_pil_image(self.image1.image_name)
@ -244,21 +222,17 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
@title("Extract Image Channel")
@tags("image", "channel")
class ImageChannelInvocation(BaseInvocation):
"""Gets a channel from an image."""
# fmt: off
# Metadata
type: Literal["img_chan"] = "img_chan"
# Inputs
image: Optional[ImageField] = Field(default=None, description="The image to get the channel from")
channel: IMAGE_CHANNELS = Field(default="A", description="The channel to get")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Image Channel", "tags": ["image", "channel"]},
}
image: ImageField = InputField(description="The image to get the channel from")
channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@ -284,21 +258,17 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
@title("Convert Image Mode")
@tags("image", "convert")
class ImageConvertInvocation(BaseInvocation):
"""Converts an image to a different mode."""
# fmt: off
# Metadata
type: Literal["img_conv"] = "img_conv"
# Inputs
image: Optional[ImageField] = Field(default=None, description="The image to convert")
mode: IMAGE_MODES = Field(default="L", description="The mode to convert to")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Convert Image", "tags": ["image", "convert"]},
}
image: ImageField = InputField(description="The image to convert")
mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@ -321,22 +291,19 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
)
class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
@title("Blur Image")
@tags("image", "blur")
class ImageBlurInvocation(BaseInvocation):
"""Blurs an image"""
# fmt: off
# Metadata
type: Literal["img_blur"] = "img_blur"
# Inputs
image: Optional[ImageField] = Field(default=None, description="The image to blur")
radius: float = Field(default=8.0, ge=0, description="The blur radius")
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Blur Image", "tags": ["image", "blur"]},
}
image: ImageField = InputField(description="The image to blur")
radius: float = InputField(default=8.0, ge=0, description="The blur radius")
# Metadata
blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@ -382,23 +349,19 @@ PIL_RESAMPLING_MAP = {
}
class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
@title("Resize Image")
@tags("image", "resize")
class ImageResizeInvocation(BaseInvocation):
"""Resizes an image to specific dimensions"""
# fmt: off
# Metadata
type: Literal["img_resize"] = "img_resize"
# Inputs
image: Optional[ImageField] = Field(default=None, description="The image to resize")
width: Union[int, None] = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
height: Union[int, None] = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Resize Image", "tags": ["image", "resize"]},
}
image: ImageField = InputField(description="The image to resize")
width: int = InputField(default=512, ge=64, multiple_of=8, description="The width to resize to (px)")
height: int = InputField(default=512, ge=64, multiple_of=8, description="The height to resize to (px)")
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@ -426,22 +389,22 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
)
class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
@title("Scale Image")
@tags("image", "scale")
class ImageScaleInvocation(BaseInvocation):
"""Scales an image by a factor"""
# fmt: off
# Metadata
type: Literal["img_scale"] = "img_scale"
# Inputs
image: Optional[ImageField] = Field(default=None, description="The image to scale")
scale_factor: Optional[float] = Field(default=2.0, gt=0, description="The factor by which to scale the image")
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Scale Image", "tags": ["image", "scale"]},
}
image: ImageField = InputField(description="The image to scale")
scale_factor: float = InputField(
default=2.0,
gt=0,
description="The factor by which to scale the image",
)
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@ -471,22 +434,18 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
)
class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
@title("Lerp Image")
@tags("image", "lerp")
class ImageLerpInvocation(BaseInvocation):
"""Linear interpolation of all pixels of an image"""
# fmt: off
# Metadata
type: Literal["img_lerp"] = "img_lerp"
# Inputs
image: Optional[ImageField] = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Image Linear Interpolation", "tags": ["image", "linear", "interpolation", "lerp"]},
}
image: ImageField = InputField(description="The image to lerp")
min: int = InputField(default=0, ge=0, le=255, description="The minimum output value")
max: int = InputField(default=255, ge=0, le=255, description="The maximum output value")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@ -512,25 +471,18 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
)
class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
@title("Inverse Lerp Image")
@tags("image", "ilerp")
class ImageInverseLerpInvocation(BaseInvocation):
"""Inverse linear interpolation of all pixels of an image"""
# fmt: off
# Metadata
type: Literal["img_ilerp"] = "img_ilerp"
# Inputs
image: Optional[ImageField] = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Image Inverse Linear Interpolation",
"tags": ["image", "linear", "interpolation", "inverse"],
},
}
image: ImageField = InputField(description="The image to lerp")
min: int = InputField(default=0, ge=0, le=255, description="The minimum input value")
max: int = InputField(default=255, ge=0, le=255, description="The maximum input value")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@ -556,21 +508,19 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
)
class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
@title("Blur NSFW Image")
@tags("image", "nsfw")
class ImageNSFWBlurInvocation(BaseInvocation):
"""Add blur to NSFW-flagged images"""
# fmt: off
# Metadata
type: Literal["img_nsfw"] = "img_nsfw"
# Inputs
image: Optional[ImageField] = Field(default=None, description="The image to check")
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Blur NSFW Images", "tags": ["image", "nsfw", "checker"]},
}
image: ImageField = InputField(description="The image to check")
metadata: Optional[CoreMetadata] = InputField(
default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@ -607,22 +557,20 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
return caution.resize((caution.width // 2, caution.height // 2))
class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
@title("Add Invisible Watermark")
@tags("image", "watermark")
class ImageWatermarkInvocation(BaseInvocation):
"""Add an invisible watermark to an image"""
# fmt: off
# Metadata
type: Literal["img_watermark"] = "img_watermark"
# Inputs
image: Optional[ImageField] = Field(default=None, description="The image to check")
text: str = Field(default='InvokeAI', description="Watermark text")
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Add Invisible Watermark", "tags": ["image", "watermark", "invisible"]},
}
image: ImageField = InputField(description="The image to check")
text: str = InputField(default="InvokeAI", description="Watermark text")
metadata: Optional[CoreMetadata] = InputField(
default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@ -644,19 +592,21 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
)
class MaskEdgeInvocation(BaseInvocation, PILInvocationConfig):
@title("Mask Edge")
@tags("image", "mask", "inpaint")
class MaskEdgeInvocation(BaseInvocation):
"""Applies an edge mask to an image"""
# fmt: off
type: Literal["mask_edge"] = "mask_edge"
# Inputs
image: Optional[ImageField] = Field(default=None, description="The image to apply the mask to")
edge_size: int = Field(description="The size of the edge")
edge_blur: int = Field(description="The amount of blur on the edge")
low_threshold: int = Field(description="First threshold for the hysteresis procedure in Canny edge detection")
high_threshold: int = Field(description="Second threshold for the hysteresis procedure in Canny edge detection")
# fmt: on
image: ImageField = InputField(description="The image to apply the mask to")
edge_size: int = InputField(description="The size of the edge")
edge_blur: int = InputField(description="The amount of blur on the edge")
low_threshold: int = InputField(description="First threshold for the hysteresis procedure in Canny edge detection")
high_threshold: int = InputField(
description="Second threshold for the hysteresis procedure in Canny edge detection"
)
def invoke(self, context: InvocationContext) -> MaskOutput:
mask = context.services.images.get_pil_image(self.image.image_name)
@ -690,21 +640,16 @@ class MaskEdgeInvocation(BaseInvocation, PILInvocationConfig):
)
class MaskCombineInvocation(BaseInvocation, PILInvocationConfig):
@title("Combine Mask")
@tags("image", "mask", "multiply")
class MaskCombineInvocation(BaseInvocation):
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
# fmt: off
type: Literal["mask_combine"] = "mask_combine"
# Inputs
mask1: ImageField = Field(default=None, description="The first mask to combine")
mask2: ImageField = Field(default=None, description="The second image to combine")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Mask Combine", "tags": ["mask", "combine"]},
}
mask1: ImageField = InputField(description="The first mask to combine")
mask2: ImageField = InputField(description="The second image to combine")
def invoke(self, context: InvocationContext) -> ImageOutput:
mask1 = context.services.images.get_pil_image(self.mask1.image_name).convert("L")
@ -728,7 +673,9 @@ class MaskCombineInvocation(BaseInvocation, PILInvocationConfig):
)
class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
@title("Color Correct")
@tags("image", "color")
class ColorCorrectInvocation(BaseInvocation):
"""
Shifts the colors of a target image to match the reference image, optionally
using a mask to only color-correct certain regions of the target image.
@ -736,10 +683,11 @@ class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["color_correct"] = "color_correct"
image: Optional[ImageField] = Field(default=None, description="The image to color-correct")
reference: Optional[ImageField] = Field(default=None, description="Reference image for color-correction")
mask: Optional[ImageField] = Field(default=None, description="Mask to use when applying color-correction")
mask_blur_radius: float = Field(default=8, description="Mask blur radius")
# Inputs
image: ImageField = InputField(description="The image to color-correct")
reference: ImageField = InputField(description="Reference image for color-correction")
mask: Optional[ImageField] = InputField(default=None, description="Mask to use when applying color-correction")
mask_blur_radius: float = InputField(default=8, description="Mask blur radius")
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_init_mask = None
@ -833,16 +781,16 @@ class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
)
@title("Image Hue Adjustment")
@tags("image", "hue", "hsl")
class ImageHueAdjustmentInvocation(BaseInvocation):
"""Adjusts the Hue of an image."""
# fmt: off
type: Literal["img_hue_adjust"] = "img_hue_adjust"
# Inputs
image: ImageField = Field(default=None, description="The image to adjust")
hue: int = Field(default=0, description="The degrees by which to rotate the hue, 0-360")
# fmt: on
image: ImageField = InputField(description="The image to adjust")
hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360")
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_image = context.services.images.get_pil_image(self.image.image_name)
@ -877,16 +825,18 @@ class ImageHueAdjustmentInvocation(BaseInvocation):
)
@title("Image Luminosity Adjustment")
@tags("image", "luminosity", "hsl")
class ImageLuminosityAdjustmentInvocation(BaseInvocation):
"""Adjusts the Luminosity (Value) of an image."""
# fmt: off
type: Literal["img_luminosity_adjust"] = "img_luminosity_adjust"
# Inputs
image: ImageField = Field(default=None, description="The image to adjust")
luminosity: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)")
# fmt: on
image: ImageField = InputField(description="The image to adjust")
luminosity: float = InputField(
default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)"
)
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_image = context.services.images.get_pil_image(self.image.image_name)
@ -925,16 +875,16 @@ class ImageLuminosityAdjustmentInvocation(BaseInvocation):
)
@title("Image Saturation Adjustment")
@tags("image", "saturation", "hsl")
class ImageSaturationAdjustmentInvocation(BaseInvocation):
"""Adjusts the Saturation of an image."""
# fmt: off
type: Literal["img_saturation_adjust"] = "img_saturation_adjust"
# Inputs
image: ImageField = Field(default=None, description="The image to adjust")
saturation: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation")
# fmt: on
image: ImageField = InputField(description="The image to adjust")
saturation: float = InputField(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation")
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_image = context.services.images.get_pil_image(self.image.image_name)

View File

@ -5,18 +5,13 @@ from typing import Literal, Optional, get_args
import numpy as np
import math
from PIL import Image, ImageOps
from pydantic import Field
from invokeai.app.invocations.image import ImageOutput
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.image_util.patchmatch import PatchMatch
from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (
BaseInvocation,
InvocationConfig,
InvocationContext,
)
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UITypeHint, title, tags
def infill_methods() -> list[str]:
@ -114,21 +109,20 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
return si
@title("Solid Color Infill")
@tags("image", "inpaint")
class InfillColorInvocation(BaseInvocation):
"""Infills transparent areas of an image with a solid color"""
type: Literal["infill_rgba"] = "infill_rgba"
image: Optional[ImageField] = Field(default=None, description="The image to infill")
color: ColorField = Field(
# Inputs
image: ImageField = InputField(description="The image to infill")
color: ColorField = InputField(
default=ColorField(r=127, g=127, b=127, a=255),
description="The color to use to infill",
)
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Color Infill", "tags": ["image", "inpaint", "color", "infill"]},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@ -153,25 +147,23 @@ class InfillColorInvocation(BaseInvocation):
)
@title("Tile Infill")
@tags("image", "inpaint")
class InfillTileInvocation(BaseInvocation):
"""Infills transparent areas of an image with tiles of the image"""
type: Literal["infill_tile"] = "infill_tile"
image: Optional[ImageField] = Field(default=None, description="The image to infill")
tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
seed: int = Field(
# Input
image: ImageField = InputField(description="The image to infill")
tile_size: int = InputField(default=32, ge=1, description="The tile size (px)")
seed: int = InputField(
ge=0,
le=SEED_MAX,
description="The seed to use for tile generation (omit for random)",
default_factory=get_random_seed,
)
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Tile Infill", "tags": ["image", "inpaint", "tile", "infill"]},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@ -194,17 +186,15 @@ class InfillTileInvocation(BaseInvocation):
)
@title("PatchMatch Infill")
@tags("image", "inpaint")
class InfillPatchMatchInvocation(BaseInvocation):
"""Infills transparent areas of an image using the PatchMatch algorithm"""
type: Literal["infill_patchmatch"] = "infill_patchmatch"
image: Optional[ImageField] = Field(default=None, description="The image to infill")
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Patch Match Infill", "tags": ["image", "inpaint", "patchmatch", "infill"]},
}
# Inputs
image: ImageField = InputField(description="The image to infill")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)

View File

@ -13,7 +13,8 @@ from diffusers.models.attention_processor import (
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from diffusers.schedulers import DPMSolverSDEScheduler, SchedulerMixin as Scheduler
from diffusers.schedulers import DPMSolverSDEScheduler
from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import BaseModel, Field, validator
from torchvision.transforms.functional import resize as tv_resize
@ -23,6 +24,7 @@ from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from ...backend.model_management import BaseModelType, ModelPatcher
from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData,
@ -32,9 +34,20 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
)
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_precision, choose_torch_device, torch_dtype
from ...backend.util.devices import choose_precision, choose_torch_device
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
OutputField,
UITypeHint,
tags,
title,
)
from .compel import ConditioningField
from .controlnet_image_processors import ControlField
from .image import ImageOutput
@ -46,8 +59,8 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
class LatentsField(BaseModel):
"""A latents field used for passing latents between invocations"""
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
seed: Optional[int] = Field(description="Seed used to generate this latents")
latents_name: str = Field(description="The name of the latents")
seed: Optional[int] = Field(default=None, description="Seed used to generate this latents")
class Config:
schema_extra = {"required": ["latents_name"]}
@ -56,14 +69,14 @@ class LatentsField(BaseModel):
class LatentsOutput(BaseInvocationOutput):
"""Base class for invocations that output latents"""
# fmt: off
type: Literal["latents_output"] = "latents_output"
# Inputs
latents: LatentsField = Field(default=None, description="The output latents")
width: int = Field(description="The width of the latents in pixels")
height: int = Field(description="The height of the latents in pixels")
# fmt: on
latents: LatentsField = OutputField(
description=FieldDescriptions.latents,
)
width: int = OutputField(description=FieldDescriptions.width)
height: int = OutputField(description=FieldDescriptions.height)
def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int]):
@ -111,30 +124,36 @@ def get_scheduler(
return scheduler
@title("Denoise Latents")
@tags("latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l")
class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images"""
type: Literal["denoise_latents"] = "denoise_latents"
# Inputs
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: Union[float, List[float]] = Field(
default=7.5,
ge=1,
description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt",
positive_conditioning: ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use")
unet: UNetField = Field(default=None, description="UNet submodel")
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
mask: Optional[ImageField] = Field(
None,
description="Mask",
negative_conditioning: ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection
)
noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection)
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
cfg_scale: Union[float, List[float]] = InputField(
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, ui_type_hint=UITypeHint.Float
)
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
scheduler: SAMPLER_NAME_VALUES = InputField(default="euler", description=FieldDescriptions.scheduler)
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection)
control: Union[ControlField, list[ControlField]] = InputField(
default=None, description=FieldDescriptions.control, input=Input.Connection
)
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
mask: Optional[ImageField] = InputField(
default=None,
description=FieldDescriptions.mask,
)
@validator("cfg_scale")
@ -149,20 +168,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
raise ValueError("cfg_scale must be greater than 1")
return v
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Denoise Latents",
"tags": ["denoise", "latents"],
"type_hints": {
"model": "model",
"control": "control",
"cfg_scale": "number",
},
},
}
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self,
@ -474,29 +479,29 @@ class DenoiseLatentsInvocation(BaseInvocation):
return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
# Latent to image
@title("Latents to Image")
@tags("latents", "image", "vae")
class LatentsToImageInvocation(BaseInvocation):
"""Generates an image from latents."""
type: Literal["l2i"] = "l2i"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
vae: VaeField = Field(default=None, description="Vae submodel")
tiled: bool = Field(default=False, description="Decode latents by overlaping tiles (less memory consumption)")
fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
metadata: Optional[CoreMetadata] = Field(
default=None, description="Optional core metadata to be written to the image"
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
vae: VaeField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
metadata: CoreMetadata = InputField(
default=None,
description=FieldDescriptions.core_metadata,
ui_hidden=True,
)
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Latents To Image",
"tags": ["latents", "image"],
},
}
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
@ -574,24 +579,30 @@ class LatentsToImageInvocation(BaseInvocation):
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
@title("Resize Latents")
@tags("latents", "resize")
class ResizeLatentsInvocation(BaseInvocation):
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
type: Literal["lresize"] = "lresize"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to resize")
width: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The width to resize to (px)")
height: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The height to resize to (px)")
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
antialias: bool = Field(
default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Resize Latents", "tags": ["latents", "resize"]},
}
width: int = InputField(
ge=64,
multiple_of=8,
description=FieldDescriptions.width,
)
height: int = InputField(
ge=64,
multiple_of=8,
description=FieldDescriptions.width,
)
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
@ -616,23 +627,21 @@ class ResizeLatentsInvocation(BaseInvocation):
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@title("Scale Latents")
@tags("latents", "resize")
class ScaleLatentsInvocation(BaseInvocation):
"""Scales latents by a given factor."""
type: Literal["lscale"] = "lscale"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to scale")
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
antialias: bool = Field(
default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Scale Latents", "tags": ["latents", "scale"]},
}
scale_factor: float = InputField(gt=0, description=FieldDescriptions.scale_factor)
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
@ -658,22 +667,23 @@ class ScaleLatentsInvocation(BaseInvocation):
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@title("Image to Latents")
@tags("latents", "image", "vae")
class ImageToLatentsInvocation(BaseInvocation):
"""Encodes an image into latents."""
type: Literal["i2l"] = "i2l"
# Inputs
image: Optional[ImageField] = Field(description="The image to encode")
vae: VaeField = Field(default=None, description="Vae submodel")
tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)")
fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Image To Latents", "tags": ["latents", "image"]},
}
image: ImageField = InputField(
description="The image to encode",
)
vae: VaeField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:

View File

@ -2,134 +2,104 @@
from typing import Literal
from pydantic import BaseModel, Field
import numpy as np
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
InputField,
InvocationContext,
InvocationConfig,
OutputField,
tags,
title,
)
class MathInvocationConfig(BaseModel):
"""Helper class to provide all math invocations with additional config"""
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["math"],
}
}
class IntOutput(BaseInvocationOutput):
"""An integer output"""
# fmt: off
type: Literal["int_output"] = "int_output"
a: int = Field(default=None, description="The output integer")
# fmt: on
a: int = OutputField(default=None, description="The output integer")
class FloatOutput(BaseInvocationOutput):
"""A float output"""
# fmt: off
type: Literal["float_output"] = "float_output"
param: float = Field(default=None, description="The output float")
# fmt: on
a: float = OutputField(default=None, description="The output float")
class AddInvocation(BaseInvocation, MathInvocationConfig):
@title("Add Integers")
@tags("math")
class AddInvocation(BaseInvocation):
"""Adds two numbers"""
# fmt: off
type: Literal["add"] = "add"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Add", "tags": ["math", "add"]},
}
# Inputs
a: int = InputField(default=0, description=FieldDescriptions.num_1)
b: int = InputField(default=0, description=FieldDescriptions.num_2)
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a + self.b)
class SubtractInvocation(BaseInvocation, MathInvocationConfig):
@title("Subtract Integers")
@tags("math")
class SubtractInvocation(BaseInvocation):
"""Subtracts two numbers"""
# fmt: off
type: Literal["sub"] = "sub"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Subtract", "tags": ["math", "subtract"]},
}
# Inputs
a: int = InputField(default=0, description=FieldDescriptions.num_1)
b: int = InputField(default=0, description=FieldDescriptions.num_2)
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a - self.b)
class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
@title("Multiply Integers")
@tags("math")
class MultiplyInvocation(BaseInvocation):
"""Multiplies two numbers"""
# fmt: off
type: Literal["mul"] = "mul"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Multiply", "tags": ["math", "multiply"]},
}
# Inputs
a: int = InputField(default=0, description=FieldDescriptions.num_1)
b: int = InputField(default=0, description=FieldDescriptions.num_2)
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a * self.b)
class DivideInvocation(BaseInvocation, MathInvocationConfig):
@title("Divide Integers")
@tags("math")
class DivideInvocation(BaseInvocation):
"""Divides two numbers"""
# fmt: off
type: Literal["div"] = "div"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Divide", "tags": ["math", "divide"]},
}
# Inputs
a: int = InputField(default=0, description=FieldDescriptions.num_1)
b: int = InputField(default=0, description=FieldDescriptions.num_2)
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=int(self.a / self.b))
@title("Random Integer")
@tags("math")
class RandomIntInvocation(BaseInvocation):
"""Outputs a single random integer."""
# fmt: off
type: Literal["rand_int"] = "rand_int"
low: int = Field(default=0, description="The inclusive low value")
high: int = Field(
default=np.iinfo(np.int32).max, description="The exclusive high value"
)
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Random Integer", "tags": ["math", "random", "integer"]},
}
# Inputs
low: int = InputField(default=0, description="The inclusive low value")
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=np.random.randint(self.low, self.high))

View File

@ -1,18 +1,21 @@
from typing import Literal, Optional, Union
from typing import Literal, Optional
from pydantic import Field
from ...version import __version__
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationConfig,
InputField,
InvocationContext,
tags,
title,
)
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from ...version import __version__
class LoRAMetadataField(BaseModelExcludeNull):
"""LoRA metadata for an image generated in InvokeAI."""
@ -43,37 +46,37 @@ class CoreMetadata(BaseModelExcludeNull):
model: MainModelField = Field(description="The main model used for inference")
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
vae: Union[VAEModelField, None] = Field(
vae: Optional[VAEModelField] = Field(
default=None,
description="The VAE used for decoding, if the main model's default was not used",
)
# Latents-to-Latents
strength: Union[float, None] = Field(
strength: Optional[float] = Field(
default=None,
description="The strength used for latents-to-latents",
)
init_image: Union[str, None] = Field(default=None, description="The name of the initial image")
init_image: Optional[str] = Field(default=None, description="The name of the initial image")
# SDXL
positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter")
negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter")
positive_style_prompt: Optional[str] = Field(default=None, description="The positive style prompt parameter")
negative_style_prompt: Optional[str] = Field(default=None, description="The negative style prompt parameter")
# SDXL Refiner
refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used")
refiner_cfg_scale: Union[float, None] = Field(
refiner_model: Optional[MainModelField] = Field(default=None, description="The SDXL Refiner model used")
refiner_cfg_scale: Optional[float] = Field(
default=None,
description="The classifier-free guidance scale parameter used for the refiner",
)
refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner")
refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner")
refiner_positive_aesthetic_store: Union[float, None] = Field(
refiner_steps: Optional[int] = Field(default=None, description="The number of steps used for the refiner")
refiner_scheduler: Optional[str] = Field(default=None, description="The scheduler used for the refiner")
refiner_positive_aesthetic_store: Optional[float] = Field(
default=None, description="The aesthetic score used for the refiner"
)
refiner_negative_aesthetic_store: Union[float, None] = Field(
refiner_negative_aesthetic_store: Optional[float] = Field(
default=None, description="The aesthetic score used for the refiner"
)
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
refiner_start: Optional[float] = Field(default=None, description="The start value used for refiner denoising")
class ImageMetadata(BaseModelExcludeNull):
@ -94,66 +97,83 @@ class MetadataAccumulatorOutput(BaseInvocationOutput):
metadata: CoreMetadata = Field(description="The core metadata for the image")
@title("Metadata Accumulator")
@tags("metadata")
class MetadataAccumulatorInvocation(BaseInvocation):
"""Outputs a Core Metadata Object"""
type: Literal["metadata_accumulator"] = "metadata_accumulator"
generation_mode: str = Field(
generation_mode: str = InputField(
description="The generation mode that output this image",
)
positive_prompt: str = Field(description="The positive prompt parameter")
negative_prompt: str = Field(description="The negative prompt parameter")
width: int = Field(description="The width parameter")
height: int = Field(description="The height parameter")
seed: int = Field(description="The seed used for noise generation")
rand_device: str = Field(description="The device used for random number generation")
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
steps: int = Field(description="The number of steps used for inference")
scheduler: str = Field(description="The scheduler used for inference")
clip_skip: int = Field(
positive_prompt: str = InputField(description="The positive prompt parameter")
negative_prompt: str = InputField(description="The negative prompt parameter")
width: int = InputField(description="The width parameter")
height: int = InputField(description="The height parameter")
seed: int = InputField(description="The seed used for noise generation")
rand_device: str = InputField(description="The device used for random number generation")
cfg_scale: float = InputField(description="The classifier-free guidance scale parameter")
steps: int = InputField(description="The number of steps used for inference")
scheduler: str = InputField(description="The scheduler used for inference")
clip_skip: int = InputField(
description="The number of skipped CLIP layers",
)
model: MainModelField = Field(description="The main model used for inference")
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
strength: Union[float, None] = Field(
model: MainModelField = InputField(description="The main model used for inference")
controlnets: list[ControlField] = InputField(description="The ControlNets used for inference")
loras: list[LoRAMetadataField] = InputField(description="The LoRAs used for inference")
strength: Optional[float] = InputField(
default=None,
description="The strength used for latents-to-latents",
)
init_image: Union[str, None] = Field(default=None, description="The name of the initial image")
vae: Union[VAEModelField, None] = Field(
init_image: Optional[str] = InputField(
default=None,
description="The name of the initial image",
)
vae: Optional[VAEModelField] = InputField(
default=None,
description="The VAE used for decoding, if the main model's default was not used",
)
# SDXL
positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter")
negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter")
positive_style_prompt: Optional[str] = InputField(
default=None,
description="The positive style prompt parameter",
)
negative_style_prompt: Optional[str] = InputField(
default=None,
description="The negative style prompt parameter",
)
# SDXL Refiner
refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used")
refiner_cfg_scale: Union[float, None] = Field(
refiner_model: Optional[MainModelField] = InputField(
default=None,
description="The SDXL Refiner model used",
)
refiner_cfg_scale: Optional[float] = InputField(
default=None,
description="The classifier-free guidance scale parameter used for the refiner",
)
refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner")
refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner")
refiner_positive_aesthetic_score: Union[float, None] = Field(
default=None, description="The aesthetic score used for the refiner"
refiner_steps: Optional[int] = InputField(
default=None,
description="The number of steps used for the refiner",
)
refiner_negative_aesthetic_score: Union[float, None] = Field(
default=None, description="The aesthetic score used for the refiner"
refiner_scheduler: Optional[str] = InputField(
default=None,
description="The scheduler used for the refiner",
)
refiner_positive_aesthetic_store: Optional[float] = InputField(
default=None,
description="The aesthetic score used for the refiner",
)
refiner_negative_aesthetic_store: Optional[float] = InputField(
default=None,
description="The aesthetic score used for the refiner",
)
refiner_start: Optional[float] = InputField(
default=None,
description="The start value used for refiner denoising",
)
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Metadata Accumulator",
"tags": ["image", "metadata", "generation"],
},
}
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
"""Collects and outputs a CoreMetadata object"""

View File

@ -4,7 +4,18 @@ from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field
from ...backend.model_management import BaseModelType, ModelType, SubModelType
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
InputField,
Input,
InvocationContext,
OutputField,
UITypeHint,
tags,
title,
)
class ModelInfo(BaseModel):
@ -39,13 +50,11 @@ class VaeField(BaseModel):
class ModelLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
# fmt: off
type: Literal["model_loader_output"] = "model_loader_output"
unet: UNetField = Field(default=None, description="UNet submodel")
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
vae: VaeField = Field(default=None, description="Vae submodel")
# fmt: on
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
class MainModelField(BaseModel):
@ -63,24 +72,17 @@ class LoRAModelField(BaseModel):
base_model: BaseModelType = Field(description="Base model")
@title("Main Model Loader")
@tags("model")
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
type: Literal["main_model_loader"] = "main_model_loader"
model: MainModelField = Field(description="The model to load")
# Inputs
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
# TODO: precision?
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Model Loader",
"tags": ["model", "loader"],
"type_hints": {"model": "model"},
},
}
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
@ -155,22 +157,6 @@ class MainModelLoaderInvocation(BaseInvocation):
loras=[],
skipped_layers=0,
),
clip2=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
@ -188,30 +174,27 @@ class LoraLoaderOutput(BaseInvocationOutput):
# fmt: off
type: Literal["lora_loader_output"] = "lora_loader_output"
unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
# fmt: on
@title("LoRA Loader")
@tags("lora", "model")
class LoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
type: Literal["lora_loader"] = "lora_loader"
lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name")
weight: float = Field(default=0.75, description="With what weight to apply lora")
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Lora Loader",
"tags": ["lora", "loader"],
"type_hints": {"lora": "lora_model"},
},
}
# Inputs
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
)
clip: Optional[ClipField] = InputField(
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP"
)
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
if self.lora is None:
@ -263,37 +246,35 @@ class LoraLoaderInvocation(BaseInvocation):
class SDXLLoraLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
"""SDXL LoRA Loader Output"""
# fmt: off
type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output"
unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
clip2: Optional[ClipField] = Field(default=None, description="Tokenizer2 and text_encoder2 submodels")
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
# fmt: on
@title("SDXL LoRA Loader")
@tags("sdxl", "lora", "model")
class SDXLLoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader"
lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name")
weight: float = Field(default=0.75, description="With what weight to apply lora")
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
clip2: Optional[ClipField] = Field(description="Clip2 model for applying lora")
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Lora Loader",
"tags": ["lora", "loader"],
"type_hints": {"lora": "lora_model"},
},
}
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
weight: float = Field(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = Field(
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNET"
)
clip: Optional[ClipField] = Field(
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1"
)
clip2: Optional[ClipField] = Field(
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2"
)
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
if self.lora is None:
@ -369,29 +350,23 @@ class VAEModelField(BaseModel):
class VaeLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
# fmt: off
type: Literal["vae_loader_output"] = "vae_loader_output"
vae: VaeField = Field(default=None, description="Vae model")
# fmt: on
# Outputs
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@title("VAE Loader")
@tags("vae", "model")
class VaeLoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""
type: Literal["vae_loader"] = "vae_loader"
vae_model: VAEModelField = Field(description="The VAE to load")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "VAE Loader",
"tags": ["vae", "loader"],
"type_hints": {"vae_model": "vae_model"},
},
}
# Inputs
vae_model: VAEModelField = InputField(
description=FieldDescriptions.vae_model, input=Input.Direct, ui_type_hint=UITypeHint.VaeModelField, title="VAE"
)
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
base_model = self.vae_model.base_model

View File

@ -1,19 +1,24 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
import math
from typing import Literal
from pydantic import Field, validator
import torch
from invokeai.app.invocations.latent import LatentsField
from pydantic import validator
from invokeai.app.invocations.latent import LatentsField
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from ...backend.util.devices import choose_torch_device, torch_dtype
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationConfig,
FieldDescriptions,
InputField,
InvocationContext,
OutputField,
UITypeHint,
tags,
title,
)
"""
@ -61,14 +66,12 @@ Nodes
class NoiseOutput(BaseInvocationOutput):
"""Invocation noise output"""
# fmt: off
type: Literal["noise_output"] = "noise_output"
type: Literal["noise_output"] = "noise_output"
# Inputs
noise: LatentsField = Field(default=None, description="The output noise")
width: int = Field(description="The width of the noise in pixels")
height: int = Field(description="The height of the noise in pixels")
# fmt: on
noise: LatentsField = OutputField(default=None, description=FieldDescriptions.noise)
width: int = OutputField(description=FieldDescriptions.width)
height: int = OutputField(description=FieldDescriptions.height)
def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
@ -79,44 +82,37 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
)
@title("Noise")
@tags("latents", "noise")
class NoiseInvocation(BaseInvocation):
"""Generates latent noise."""
type: Literal["noise"] = "noise"
# Inputs
seed: int = Field(
seed: int = InputField(
ge=0,
le=SEED_MAX,
description="The seed to use",
description=FieldDescriptions.seed,
default_factory=get_random_seed,
)
width: int = Field(
width: int = InputField(
default=512,
multiple_of=8,
gt=0,
description="The width of the resulting noise",
description=FieldDescriptions.width,
)
height: int = Field(
height: int = InputField(
default=512,
multiple_of=8,
gt=0,
description="The height of the resulting noise",
description=FieldDescriptions.height,
)
use_cpu: bool = Field(
use_cpu: bool = InputField(
default=True,
description="Use CPU for noise generation (for reproducible results across platforms)",
)
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Noise",
"tags": ["latents", "noise"],
},
}
@validator("seed", pre=True)
def modulo_seed(cls, v):
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""

View File

@ -1,37 +1,44 @@
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)
import inspect
import re
from contextlib import ExitStack
from typing import List, Literal, Optional, Union
import re
import inspect
from pydantic import BaseModel, Field, validator
import torch
import numpy as np
import torch
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ...backend.model_management import ONNXModelPatcher
from ...backend.util import choose_torch_device
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .compel import ConditioningField
from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField
from pydantic import BaseModel, Field, validator
from tqdm import tqdm
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.backend import BaseModelType, ModelType, SubModelType
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend import BaseModelType, ModelType, SubModelType
from ...backend.model_management import ONNXModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
from tqdm import tqdm
from .model import ClipField
from .latent import LatentsField, LatentsOutput, build_latents_output, get_scheduler, SAMPLER_NAME_VALUES
from .compel import CompelOutput
from ...backend.util import choose_torch_device
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
InputField,
Input,
InvocationContext,
OutputField,
UIComponent,
UITypeHint,
tags,
title,
)
from .compel import CompelOutput, ConditioningField
from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler
from .model import ClipField, ModelInfo, UNetField, VaeField
ORT_TO_NP_TYPE = {
"tensor(bool)": np.bool_,
@ -51,11 +58,13 @@ ORT_TO_NP_TYPE = {
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
@title("ONNX Prompt (Raw)")
@tags("onnx", "prompt")
class ONNXPromptInvocation(BaseInvocation):
type: Literal["prompt_onnx"] = "prompt_onnx"
prompt: str = Field(default="", description="Prompt")
clip: ClipField = Field(None, description="Clip to use")
prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
def invoke(self, context: InvocationContext) -> CompelOutput:
tokenizer_info = context.services.model_manager.get_model(
@ -134,25 +143,48 @@ class ONNXPromptInvocation(BaseInvocation):
# Text to image
@title("ONNX Text to Latents")
@tags("latents", "inference", "txt2img", "onnx")
class ONNXTextToLatentsInvocation(BaseInvocation):
"""Generates latents from conditionings."""
type: Literal["t2l_onnx"] = "t2l_onnx"
# Inputs
# fmt: off
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
precision: PRECISION_VALUES = Field(default = "tensor(float16)", description="The precision to use when generating latents")
unet: UNetField = Field(default=None, description="UNet submodel")
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
# fmt: on
positive_conditioning: ConditioningField = InputField(
description=FieldDescriptions.positive_cond,
input=Input.Connection,
)
negative_conditioning: ConditioningField = InputField(
description=FieldDescriptions.negative_cond,
input=Input.Connection,
)
noise: LatentsField = InputField(
description=FieldDescriptions.noise,
input=Input.Connection,
)
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
cfg_scale: Union[float, List[float]] = InputField(
default=7.5,
ge=1,
description=FieldDescriptions.cfg_scale,
ui_type_hint=UITypeHint.Float,
)
scheduler: SAMPLER_NAME_VALUES = InputField(
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct
)
precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision)
unet: UNetField = InputField(
description=FieldDescriptions.unet,
input=Input.Connection,
)
control: Optional[Union[ControlField, list[ControlField]]] = InputField(
default=None,
description=FieldDescriptions.control,
ui_type_hint=UITypeHint.ControlField,
)
# seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", )
# seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'")
@validator("cfg_scale")
def ge_one(cls, v):
@ -166,20 +198,6 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
raise ValueError("cfg_scale must be greater than 1")
return v
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents"],
"type_hints": {
"model": "model",
"control": "control",
# "cfg_scale": "float",
"cfg_scale": "number",
},
},
}
# based on
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
def invoke(self, context: InvocationContext) -> LatentsOutput:
@ -300,26 +318,28 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
# Latent to image
@title("ONNX Latents to Image")
@tags("latents", "image", "vae", "onnx")
class ONNXLatentsToImageInvocation(BaseInvocation):
"""Generates an image from latents."""
type: Literal["l2i_onnx"] = "l2i_onnx"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
vae: VaeField = Field(default=None, description="Vae submodel")
metadata: Optional[CoreMetadata] = Field(
default=None, description="Optional core metadata to be written to the image"
latents: LatentsField = InputField(
description=FieldDescriptions.denoised_latents,
input=Input.Connection,
)
# tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents", "image"],
},
}
vae: VaeField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
metadata: Optional[CoreMetadata] = InputField(
default=None,
description=FieldDescriptions.core_metadata,
ui_hidden=True,
)
# tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.services.latents.get(self.latents.latents_name)
@ -373,89 +393,13 @@ class ONNXModelLoaderOutput(BaseInvocationOutput):
# fmt: off
type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx"
unet: UNetField = Field(default=None, description="UNet submodel")
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
vae_decoder: VaeField = Field(default=None, description="Vae submodel")
vae_encoder: VaeField = Field(default=None, description="Vae submodel")
unet: UNetField = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
vae_decoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Decoder")
vae_encoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Encoder")
# fmt: on
class ONNXSD1ModelLoaderInvocation(BaseInvocation):
"""Loading submodels of selected model."""
type: Literal["sd1_model_loader_onnx"] = "sd1_model_loader_onnx"
model_name: str = Field(default="", description="Model to load")
# TODO: precision?
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {"tags": ["model", "loader"], "type_hints": {"model_name": "model"}}, # TODO: rename to model_name?
}
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
model_name = "stable-diffusion-v1-5"
base_model = BaseModelType.StableDiffusion1
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=model_name,
base_model=BaseModelType.StableDiffusion1,
model_type=ModelType.ONNX,
):
raise Exception(f"Unkown model name: {model_name}!")
return ONNXModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=ModelType.ONNX,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=ModelType.ONNX,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=ModelType.ONNX,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=ModelType.ONNX,
submodel=SubModelType.TextEncoder,
),
loras=[],
),
vae_decoder=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=ModelType.ONNX,
submodel=SubModelType.VaeDecoder,
),
),
vae_encoder=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=ModelType.ONNX,
submodel=SubModelType.VaeEncoder,
),
),
)
class OnnxModelField(BaseModel):
"""Onnx model field"""
@ -464,22 +408,17 @@ class OnnxModelField(BaseModel):
model_type: ModelType = Field(description="Model Type")
@title("ONNX Model Loader")
@tags("onnx", "model")
class OnnxModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
type: Literal["onnx_model_loader"] = "onnx_model_loader"
model: OnnxModelField = Field(description="The model to load")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Onnx Model Loader",
"tags": ["model", "loader"],
"type_hints": {"model": "model"},
},
}
# Inputs
model: OnnxModelField = InputField(
description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type_hint=UITypeHint.ONNXModelField
)
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
base_model = self.model.base_model

View File

@ -1,73 +1,63 @@
import io
from typing import Literal, Optional, Any
from typing import Literal, Optional
# from PIL.Image import Image
import PIL.Image
from matplotlib.ticker import MaxNLocator
from matplotlib.figure import Figure
from pydantic import BaseModel, Field
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
from easing_functions import (
LinearInOut,
QuadEaseInOut,
QuadEaseIn,
QuadEaseOut,
CubicEaseInOut,
CubicEaseIn,
CubicEaseOut,
QuarticEaseInOut,
QuarticEaseIn,
QuarticEaseOut,
QuinticEaseInOut,
QuinticEaseIn,
QuinticEaseOut,
SineEaseInOut,
SineEaseIn,
SineEaseOut,
CircularEaseIn,
CircularEaseInOut,
CircularEaseOut,
ExponentialEaseInOut,
ExponentialEaseIn,
ExponentialEaseOut,
ElasticEaseIn,
ElasticEaseInOut,
ElasticEaseOut,
BackEaseIn,
BackEaseInOut,
BackEaseOut,
BounceEaseIn,
BounceEaseInOut,
BounceEaseOut,
CircularEaseIn,
CircularEaseInOut,
CircularEaseOut,
CubicEaseIn,
CubicEaseInOut,
CubicEaseOut,
ElasticEaseIn,
ElasticEaseInOut,
ElasticEaseOut,
ExponentialEaseIn,
ExponentialEaseInOut,
ExponentialEaseOut,
LinearInOut,
QuadEaseIn,
QuadEaseInOut,
QuadEaseOut,
QuarticEaseIn,
QuarticEaseInOut,
QuarticEaseOut,
QuinticEaseIn,
QuinticEaseInOut,
QuinticEaseOut,
SineEaseIn,
SineEaseInOut,
SineEaseOut,
)
from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator
from pydantic import BaseModel, Field
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationContext,
InvocationConfig,
)
from ...backend.util.logging import InvokeAILogger
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
from .collections import FloatCollectionOutput
@title("Float Range")
@tags("math", "range")
class FloatLinearRangeInvocation(BaseInvocation):
"""Creates a range"""
type: Literal["float_range"] = "float_range"
# Inputs
start: float = Field(default=5, description="The first value of the range")
stop: float = Field(default=10, description="The last value of the range")
steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)")
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Linear Range (Float)", "tags": ["math", "float", "linear", "range"]},
}
start: float = InputField(default=5, description="The first value of the range")
stop: float = InputField(default=10, description="The last value of the range")
steps: int = InputField(default=30, description="number of values to interpolate over (including start and stop)")
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
param_list = list(np.linspace(self.start, self.stop, self.steps))
@ -108,37 +98,32 @@ EASING_FUNCTIONS_MAP = {
"BounceInOut": BounceEaseInOut,
}
EASING_FUNCTION_KEYS: Any = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
# actually I think for now could just use CollectionOutput (which is list[Any]
@title("Step Param Easing")
@tags("step", "easing")
class StepParamEasingInvocation(BaseInvocation):
"""Experimental per-step parameter easing for denoising steps"""
type: Literal["step_param_easing"] = "step_param_easing"
# Inputs
# fmt: off
easing: EASING_FUNCTION_KEYS = Field(default="Linear", description="The easing function to use")
num_steps: int = Field(default=20, description="number of denoising steps")
start_value: float = Field(default=0.0, description="easing starting value")
end_value: float = Field(default=1.0, description="easing ending value")
start_step_percent: float = Field(default=0.0, description="fraction of steps at which to start easing")
end_step_percent: float = Field(default=1.0, description="fraction of steps after which to end easing")
easing: EASING_FUNCTION_KEYS = InputField(default="Linear", description="The easing function to use")
num_steps: int = InputField(default=20, description="number of denoising steps")
start_value: float = InputField(default=0.0, description="easing starting value")
end_value: float = InputField(default=1.0, description="easing ending value")
start_step_percent: float = InputField(default=0.0, description="fraction of steps at which to start easing")
end_step_percent: float = InputField(default=1.0, description="fraction of steps after which to end easing")
# if None, then start_value is used prior to easing start
pre_start_value: Optional[float] = Field(default=None, description="value before easing start")
pre_start_value: Optional[float] = InputField(default=None, description="value before easing start")
# if None, then end value is used prior to easing end
post_end_value: Optional[float] = Field(default=None, description="value after easing end")
mirror: bool = Field(default=False, description="include mirror of easing function")
post_end_value: Optional[float] = InputField(default=None, description="value after easing end")
mirror: bool = InputField(default=False, description="include mirror of easing function")
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
# alt_mirror: bool = Field(default=False, description="alternative mirroring by dual easing")
show_easing_plot: bool = Field(default=False, description="show easing plot")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Param Easing By Step", "tags": ["param", "step", "easing"]},
}
# alt_mirror: bool = InputField(default=False, description="alternative mirroring by dual easing")
show_easing_plot: bool = InputField(default=False, description="show easing plot")
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
log_diagnostics = False

View File

@ -2,82 +2,80 @@
from typing import Literal
from pydantic import Field
from invokeai.app.invocations.prompt import PromptOutput
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InputField,
InvocationContext,
OutputField,
tags,
title,
)
from .math import FloatOutput, IntOutput
# Pass-through parameter nodes - used by subgraphs
@title("Integer Parameter")
@tags("integer")
class ParamIntInvocation(BaseInvocation):
"""An integer parameter"""
# fmt: off
type: Literal["param_int"] = "param_int"
a: int = Field(default=0, description="The integer value")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"tags": ["param", "integer"], "title": "Integer Parameter"},
}
# Inputs
a: int = InputField(default=0, description="The integer value")
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a)
@title("Float Parameter")
@tags("float")
class ParamFloatInvocation(BaseInvocation):
"""A float parameter"""
# fmt: off
type: Literal["param_float"] = "param_float"
param: float = Field(default=0.0, description="The float value")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"tags": ["param", "float"], "title": "Float Parameter"},
}
# Inputs
param: float = InputField(default=0.0, description="The float value")
def invoke(self, context: InvocationContext) -> FloatOutput:
return FloatOutput(param=self.param)
return FloatOutput(a=self.param)
class StringOutput(BaseInvocationOutput):
"""A string output"""
type: Literal["string_output"] = "string_output"
text: str = Field(default=None, description="The output string")
text: str = OutputField(description="The output string")
@title("String Parameter")
@tags("string")
class ParamStringInvocation(BaseInvocation):
"""A string parameter"""
type: Literal["param_string"] = "param_string"
text: str = Field(default="", description="The string value")
class Config(InvocationConfig):
schema_extra = {
"ui": {"tags": ["param", "string"], "title": "String Parameter"},
}
# Inputs
text: str = InputField(default="", description="The string value")
def invoke(self, context: InvocationContext) -> StringOutput:
return StringOutput(text=self.text)
@title("Prompt Parameter")
@tags("prompt")
class ParamPromptInvocation(BaseInvocation):
"""A prompt input parameter"""
type: Literal["param_prompt"] = "param_prompt"
prompt: str = Field(default="", description="The prompt value")
class Config(InvocationConfig):
schema_extra = {
"ui": {"tags": ["param", "prompt"], "title": "Prompt"},
}
# Inputs
prompt: str = InputField(default="", description="The prompt value")
def invoke(self, context: InvocationContext) -> PromptOutput:
return PromptOutput(prompt=self.prompt)

View File

@ -2,56 +2,52 @@ from os.path import exists
from typing import Literal, Optional
import numpy as np
from pydantic import Field, validator
from pydantic import validator
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InputField,
InvocationContext,
OutputField,
UIComponent,
UITypeHint,
title,
tags,
)
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
class PromptOutput(BaseInvocationOutput):
"""Base class for invocations that output a prompt"""
# fmt: off
type: Literal["prompt"] = "prompt"
prompt: str = Field(default=None, description="The output prompt")
# fmt: on
class Config:
schema_extra = {
"required": [
"type",
"prompt",
]
}
prompt: str = OutputField(description="The output prompt")
class PromptCollectionOutput(BaseInvocationOutput):
"""Base class for invocations that output a collection of prompts"""
# fmt: off
type: Literal["prompt_collection_output"] = "prompt_collection_output"
prompt_collection: list[str] = Field(description="The output prompt collection")
count: int = Field(description="The size of the prompt collection")
# fmt: on
class Config:
schema_extra = {"required": ["type", "prompt_collection", "count"]}
prompt_collection: list[str] = OutputField(
description="The output prompt collection", ui_type_hint=UITypeHint.StringCollection
)
count: int = OutputField(description="The size of the prompt collection")
@title("Dynamic Prompt")
@tags("prompt", "collection")
class DynamicPromptInvocation(BaseInvocation):
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
type: Literal["dynamic_prompt"] = "dynamic_prompt"
prompt: str = Field(description="The prompt to parse with dynamicprompts")
max_prompts: int = Field(default=1, description="The number of prompts to generate")
combinatorial: bool = Field(default=False, description="Whether to use the combinatorial generator")
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Dynamic Prompt", "tags": ["prompt", "dynamic"]},
}
# Inputs
prompt: str = InputField(description="The prompt to parse with dynamicprompts", ui_component=UIComponent.Textarea)
max_prompts: int = InputField(default=1, description="The number of prompts to generate")
combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator")
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
if self.combinatorial:
@ -64,24 +60,23 @@ class DynamicPromptInvocation(BaseInvocation):
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
@title("Prompts from File")
@tags("prompt", "file")
class PromptsFromFileInvocation(BaseInvocation):
"""Loads prompts from a text file"""
# fmt: off
type: Literal['prompt_from_file'] = 'prompt_from_file'
type: Literal["prompt_from_file"] = "prompt_from_file"
# Inputs
file_path: str = Field(description="Path to prompt text file")
pre_prompt: Optional[str] = Field(description="String to prepend to each prompt")
post_prompt: Optional[str] = Field(description="String to append to each prompt")
start_line: int = Field(default=1, ge=1, description="Line in the file to start start from")
max_prompts: int = Field(default=1, ge=0, description="Max lines to read from file (0=all)")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Prompts From File", "tags": ["prompt", "file"]},
}
file_path: str = InputField(description="Path to prompt text file", ui_type_hint=UITypeHint.FilePath)
pre_prompt: Optional[str] = InputField(
description="String to prepend to each prompt", ui_component=UIComponent.Textarea
)
post_prompt: Optional[str] = InputField(
description="String to append to each prompt", ui_component=UIComponent.Textarea
)
start_line: int = InputField(default=1, ge=1, description="Line in the file to start start from")
max_prompts: int = InputField(default=1, ge=0, description="Max lines to read from file (0=all)")
@validator("file_path")
def file_path_exists(cls, v):

View File

@ -1,55 +1,55 @@
import torch
from typing import Literal
from pydantic import Field
from ...backend.model_management import ModelType, SubModelType
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
OutputField,
UITypeHint,
tags,
title,
)
from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField
class SDXLModelLoaderOutput(BaseInvocationOutput):
"""SDXL base model loader output"""
# fmt: off
type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output"
unet: UNetField = Field(default=None, description="UNet submodel")
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
vae: VaeField = Field(default=None, description="Vae submodel")
# fmt: on
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
"""SDXL refiner model loader output"""
# fmt: off
type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
unet: UNetField = Field(default=None, description="UNet submodel")
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
vae: VaeField = Field(default=None, description="Vae submodel")
# fmt: on
# fmt: on
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@title("SDXL Main Model Loader")
@tags("model", "sdxl")
class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels."""
type: Literal["sdxl_model_loader"] = "sdxl_model_loader"
model: MainModelField = Field(description="The model to load")
# Inputs
model: MainModelField = InputField(
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type_hint=UITypeHint.SDXLMainModelField
)
# TODO: precision?
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Model Loader",
"tags": ["model", "loader", "sdxl"],
"type_hints": {"model": "model"},
},
}
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
@ -122,24 +122,21 @@ class SDXLModelLoaderInvocation(BaseInvocation):
)
@title("SDXL Refiner Model Loader")
@tags("model", "sdxl", "refiner")
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels."""
type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
model: MainModelField = Field(description="The model to load")
# Inputs
model: MainModelField = InputField(
description=FieldDescriptions.sdxl_refiner_model,
input=Input.Direct,
ui_type_hint=UITypeHint.SDXLRefinerModelField,
)
# TODO: precision?
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Refiner Model Loader",
"tags": ["model", "loader", "sdxl_refiner"],
"type_hints": {"model": "refiner_model"},
},
}
def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name

View File

@ -6,12 +6,11 @@ import cv2 as cv
import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image
from pydantic import Field
from realesrgan import RealESRGANer
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
from .baseinvocation import BaseInvocation, InputField, InvocationContext, title, tags
from .image import ImageOutput
# TODO: Populate this from disk?
@ -24,17 +23,16 @@ ESRGAN_MODELS = Literal[
]
@title("Upscale (RealESRGAN)")
@tags("esrgan", "upscale")
class ESRGANInvocation(BaseInvocation):
"""Upscales an image using RealESRGAN."""
type: Literal["esrgan"] = "esrgan"
image: Union[ImageField, None] = Field(default=None, description="The input image")
model_name: ESRGAN_MODELS = Field(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Upscale (RealESRGAN)", "tags": ["image", "upscale", "realesrgan"]},
}
# Inputs
image: ImageField = InputField(description="The input image")
model_name: ESRGAN_MODELS = InputField(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)

View File

@ -5,14 +5,13 @@ from pydantic import BaseModel, Field
from invokeai.app.util.metaenum import MetaEnum
from ..invocations.baseinvocation import (
BaseInvocationOutput,
InvocationConfig,
)
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
image_name: Optional[str] = Field(default=None, description="The name of the image")
image_name: str = Field(description="The name of the image")
class Config:
schema_extra = {"required": ["image_name"]}
@ -36,17 +35,6 @@ class ProgressImage(BaseModel):
dataURL: str = Field(description="The image data as a b64 data URL")
class PILInvocationConfig(BaseModel):
"""Helper class to provide all PIL invocations with additional config"""
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["PIL", "image"],
},
}
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""

View File

@ -3,16 +3,7 @@
import copy
import itertools
import uuid
from typing import (
Annotated,
Any,
Literal,
Optional,
Union,
get_args,
get_origin,
get_type_hints,
)
from typing import Annotated, Any, Literal, Optional, Union, get_args, get_origin, get_type_hints
import networkx as nx
from pydantic import BaseModel, root_validator, validator
@ -22,7 +13,11 @@ from ..invocations import *
from ..invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Input,
InputField,
InvocationContext,
OutputField,
UITypeHint,
)
# in 3.10 this would be "from types import NoneType"
@ -183,15 +178,9 @@ class IterateInvocationOutput(BaseInvocationOutput):
type: Literal["iterate_output"] = "iterate_output"
item: Any = Field(description="The item being iterated over")
class Config:
schema_extra = {
"required": [
"type",
"item",
]
}
item: Any = OutputField(
description="The item being iterated over", title="Collection Item", ui_type_hint=UITypeHint.CollectionItem
)
# TODO: Fill this out and move to invocations
@ -200,8 +189,10 @@ class IterateInvocation(BaseInvocation):
type: Literal["iterate"] = "iterate"
collection: list[Any] = Field(description="The list of items to iterate over", default_factory=list)
index: int = Field(description="The index, will be provided on executed iterators", default=0)
collection: list[Any] = InputField(
description="The list of items to iterate over", default_factory=list, ui_type_hint=UITypeHint.Collection
)
index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True)
def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
"""Produces the outputs as values"""
@ -211,15 +202,9 @@ class IterateInvocation(BaseInvocation):
class CollectInvocationOutput(BaseInvocationOutput):
type: Literal["collect_output"] = "collect_output"
collection: list[Any] = Field(description="The collection of input items")
class Config:
schema_extra = {
"required": [
"type",
"collection",
]
}
collection: list[Any] = OutputField(
description="The collection of input items", title="Collection", ui_type_hint=UITypeHint.Collection
)
class CollectInvocation(BaseInvocation):
@ -227,13 +212,14 @@ class CollectInvocation(BaseInvocation):
type: Literal["collect"] = "collect"
item: Any = Field(
item: Any = InputField(
description="The item to collect (all inputs must be of the same type)",
default=None,
ui_type_hint=UITypeHint.CollectionItem,
title="Collection Item",
input=Input.Connection,
)
collection: list[Any] = Field(
description="The collection, will be provided on execution",
default_factory=list,
collection: list[Any] = InputField(
description="The collection, will be provided on execution", default_factory=list, ui_hidden=True
)
def invoke(self, context: InvocationContext) -> CollectInvocationOutput:

View File

@ -87,7 +87,10 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Invoke
try:
with statistics.collect_stats(invocation, graph_execution_state.id):
outputs = invocation.invoke(
# use the internal invoke_internal(), which wraps the node's invoke() method in
# this accomodates nodes which require a value, but get it only from a
# connection
outputs = invocation.invoke_internal(
InvocationContext(
services=self.__invoker.services,
graph_execution_state_id=graph_execution_state.id,

View File

@ -49,7 +49,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
def _parse_item(self, item: str) -> T:
item_type = get_args(self.__orig_class__)[0]
return parse_raw_as(item_type, item)
parsed = parse_raw_as(item_type, item)
return parsed
def set(self, item: T):
try:

View File

@ -61,6 +61,7 @@
"@dagrejs/graphlib": "^2.1.13",
"@dnd-kit/core": "^6.0.8",
"@dnd-kit/modifiers": "^6.0.1",
"@dnd-kit/utilities": "^3.2.1",
"@emotion/react": "^11.11.1",
"@emotion/styled": "^11.11.0",
"@floating-ui/react-dom": "^2.0.1",

View File

@ -0,0 +1,34 @@
export const COLORS = {
reset: '\x1b[0m',
bright: '\x1b[1m',
dim: '\x1b[2m',
underscore: '\x1b[4m',
blink: '\x1b[5m',
reverse: '\x1b[7m',
hidden: '\x1b[8m',
fg: {
black: '\x1b[30m',
red: '\x1b[31m',
green: '\x1b[32m',
yellow: '\x1b[33m',
blue: '\x1b[34m',
magenta: '\x1b[35m',
cyan: '\x1b[36m',
white: '\x1b[37m',
gray: '\x1b[90m',
crimson: '\x1b[38m',
},
bg: {
black: '\x1b[40m',
red: '\x1b[41m',
green: '\x1b[42m',
yellow: '\x1b[43m',
blue: '\x1b[44m',
magenta: '\x1b[45m',
cyan: '\x1b[46m',
white: '\x1b[47m',
gray: '\x1b[100m',
crimson: '\x1b[48m',
},
};

View File

@ -1,23 +1,83 @@
import fs from 'node:fs';
import openapiTS from 'openapi-typescript';
import { COLORS } from './colors.js';
const OPENAPI_URL = 'http://127.0.0.1:9090/openapi.json';
const OUTPUT_FILE = 'src/services/api/schema.d.ts';
async function main() {
process.stdout.write(
`Generating types "${OPENAPI_URL}" --> "${OUTPUT_FILE}"...`
`Generating types "${OPENAPI_URL}" --> "${OUTPUT_FILE}"...\n\n`
);
const types = await openapiTS(OPENAPI_URL, {
exportType: true,
transform: (schemaObject) => {
transform: (schemaObject, metadata) => {
if ('format' in schemaObject && schemaObject.format === 'binary') {
return schemaObject.nullable ? 'Blob | null' : 'Blob';
}
/**
* Because invocations may have required fields that accept connection input, the generated
* types may be incorrect.
*
* For example, the ImageResizeInvocation has a required `image` field, but because it accepts
* connection input, it should be optional on instantiation of the field.
*
* To handle this, the schema exposes an `input` property that can be used to determine if the
* field accepts connection input. If it does, we can make the field optional.
*/
// Check if we are generating types for an invocation
const isInvocationPath = metadata.path.match(
/^#\/components\/schemas\/\w*Invocation$/
);
const hasInvocationProperties =
schemaObject.properties &&
['id', 'is_intermediate', 'type'].every(
(prop) => prop in schemaObject.properties
);
if (isInvocationPath && hasInvocationProperties) {
// We only want to make fields optional if they are required
if (!Array.isArray(schemaObject?.required)) {
schemaObject.required = ['id', 'type'];
return;
}
schemaObject.required.forEach((prop) => {
const acceptsConnection = ['any', 'connection'].includes(
schemaObject.properties?.[prop]?.['input']
);
if (acceptsConnection) {
// remove this prop from the required array
const invocationName = metadata.path.split('/').pop();
console.log(
`Making connectable field optional: ${COLORS.fg.green}${invocationName}.${COLORS.fg.cyan}${prop}${COLORS.reset}`
);
schemaObject.required = schemaObject.required.filter(
(r) => r !== prop
);
}
});
schemaObject.required = [
...new Set(schemaObject.required.concat(['id', 'type'])),
];
return;
}
// if (
// 'input' in schemaObject &&
// (schemaObject.input === 'any' || schemaObject.input === 'connection')
// ) {
// schemaObject.required = false;
// }
},
});
fs.writeFileSync(OUTPUT_FILE, types);
process.stdout.write(` OK!\r\n`);
process.stdout.write(`\nOK!\r\n`);
}
main();

View File

@ -1,8 +1,12 @@
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import {
ctrlKeyPressed,
metaKeyPressed,
shiftKeyPressed,
} from 'features/ui/store/hotkeysSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import {
setActiveTab,
@ -16,11 +20,11 @@ import React, { memo } from 'react';
import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
const globalHotkeysSelector = createSelector(
[(state: RootState) => state.hotkeys, (state: RootState) => state.ui],
(hotkeys, ui) => {
const { shift } = hotkeys;
[stateSelector],
({ hotkeys, ui }) => {
const { shift, ctrl, meta } = hotkeys;
const { shouldPinParametersPanel, shouldPinGallery } = ui;
return { shift, shouldPinGallery, shouldPinParametersPanel };
return { shift, ctrl, meta, shouldPinGallery, shouldPinParametersPanel };
},
{
memoizeOptions: {
@ -37,9 +41,8 @@ const globalHotkeysSelector = createSelector(
*/
const GlobalHotkeys: React.FC = () => {
const dispatch = useAppDispatch();
const { shift, shouldPinParametersPanel, shouldPinGallery } = useAppSelector(
globalHotkeysSelector
);
const { shift, ctrl, meta, shouldPinParametersPanel, shouldPinGallery } =
useAppSelector(globalHotkeysSelector);
const activeTabName = useAppSelector(activeTabNameSelector);
useHotkeys(
@ -50,9 +53,19 @@ const GlobalHotkeys: React.FC = () => {
} else {
shift && dispatch(shiftKeyPressed(false));
}
if (isHotkeyPressed('ctrl')) {
!ctrl && dispatch(ctrlKeyPressed(true));
} else {
ctrl && dispatch(ctrlKeyPressed(false));
}
if (isHotkeyPressed('meta')) {
!meta && dispatch(metaKeyPressed(true));
} else {
meta && dispatch(metaKeyPressed(false));
}
},
{ keyup: true, keydown: true },
[shift]
[shift, ctrl, meta]
);
useHotkeys('o', () => {

View File

@ -14,7 +14,7 @@ import { $authToken, $baseUrl, $projectId } from 'services/api/client';
import { socketMiddleware } from 'services/events/middleware';
import Loading from '../../common/components/Loading/Loading';
import '../../i18n';
import ImageDndContext from './ImageDnd/ImageDndContext';
import AppDndContext from '../../features/dnd/components/AppDndContext';
const App = lazy(() => import('./App'));
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
@ -80,9 +80,9 @@ const InvokeAIUI = ({
<Provider store={store}>
<React.Suspense fallback={<Loading />}>
<ThemeLocaleProvider>
<ImageDndContext>
<AppDndContext>
<App config={config} headerComponent={headerComponent} />
</ImageDndContext>
</AppDndContext>
</ThemeLocaleProvider>
</React.Suspense>
</Provider>

View File

@ -19,7 +19,8 @@ type LoggerNamespace =
| 'nodes'
| 'system'
| 'socketio'
| 'session';
| 'session'
| 'dnd';
export const logger = (namespace: LoggerNamespace) =>
$logger.get().child({ namespace });

View File

@ -15,7 +15,7 @@ export const actionsDenylist = [
'socket/socketGeneratorProgress',
'socket/appSocketGeneratorProgress',
// every time user presses shift
'hotkeys/shiftKeyPressed',
// 'hotkeys/shiftKeyPressed',
// this happens after every state change
'@@REMEMBER_PERSISTED',
];

View File

@ -1,16 +1,20 @@
import { createAction } from '@reduxjs/toolkit';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd';
import { logger } from 'app/logging/logger';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'features/dnd/types';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
fieldImageValueChanged,
workflowExposedFieldAdded,
} from 'features/nodes/store/nodesSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '../';
import { parseify } from 'common/util/serialize';
export const dndDropped = createAction<{
overData: TypesafeDroppableData;
@ -21,7 +25,7 @@ export const addImageDroppedListener = () => {
startAppListening({
actionCreator: dndDropped,
effect: async (action, { dispatch }) => {
const log = logger('images');
const log = logger('dnd');
const { activeData, overData } = action.payload;
if (activeData.payloadType === 'IMAGE_DTO') {
@ -31,10 +35,28 @@ export const addImageDroppedListener = () => {
{ activeData, overData },
`Images (${activeData.payload.imageDTOs.length}) dropped`
);
} else if (activeData.payloadType === 'NODE_FIELD') {
log.debug(
{ activeData: parseify(activeData), overData: parseify(overData) },
'Node field dropped'
);
} else {
log.debug({ activeData, overData }, `Unknown payload dropped`);
}
if (
overData.actionType === 'ADD_FIELD_TO_LINEAR' &&
activeData.payloadType === 'NODE_FIELD'
) {
const { nodeId, field } = activeData.payload;
dispatch(
workflowExposedFieldAdded({
nodeId,
fieldName: field.name,
})
);
}
/**
* Image dropped on current image
*/
@ -99,7 +121,7 @@ export const addImageDroppedListener = () => {
) {
const { fieldName, nodeId } = overData.context;
dispatch(
fieldValueChanged({
fieldImageValueChanged({
nodeId,
fieldName,
value: activeData.payload.imageDTO,

View File

@ -2,7 +2,7 @@ import { UseToastOptions } from '@chakra-ui/react';
import { logger } from 'app/logging/logger';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { addToast } from 'features/system/store/systemSlice';
import { omit } from 'lodash-es';
@ -111,7 +111,9 @@ export const addImageUploadedFulfilledListener = () => {
if (postUploadAction?.type === 'SET_NODES_IMAGE') {
const { nodeId, fieldName } = postUploadAction;
dispatch(fieldValueChanged({ nodeId, fieldName, value: imageDTO }));
dispatch(
fieldImageValueChanged({ nodeId, fieldName, value: imageDTO })
);
dispatch(
addToast({
...DEFAULT_UPLOADED_TOAST,

View File

@ -15,12 +15,21 @@ import {
setShouldUseSDXLRefiner,
} from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es';
import { modelsApi, vaeModelsAdapter } from 'services/api/endpoints/models';
import {
mainModelsAdapter,
modelsApi,
vaeModelsAdapter,
} from 'services/api/endpoints/models';
import { TypeGuardFor } from 'services/api/types';
import { startAppListening } from '..';
export const addModelsLoadedListener = () => {
startAppListening({
predicate: (state, action) =>
predicate: (
action
): action is TypeGuardFor<
typeof modelsApi.endpoints.getMainModels.matchFulfilled
> =>
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
!action.meta.arg.originalArgs.includes('sdxl-refiner'),
effect: async (action, { getState, dispatch }) => {
@ -32,29 +41,28 @@ export const addModelsLoadedListener = () => {
);
const currentModel = getState().generation.model;
const models = mainModelsAdapter.getSelectors().selectAll(action.payload);
const isCurrentModelAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === currentModel?.model_name &&
m?.base_model === currentModel?.base_model &&
m?.model_type === currentModel?.model_type
);
if (isCurrentModelAvailable) {
return;
}
const firstModelId = action.payload.ids[0];
const firstModel = action.payload.entities[firstModelId];
if (!firstModel) {
if (models.length === 0) {
// No models loaded at all
dispatch(modelChanged(null));
return;
}
const result = zMainOrOnnxModel.safeParse(firstModel);
const isCurrentModelAvailable = currentModel
? models.some(
(m) =>
m.model_name === currentModel.model_name &&
m.base_model === currentModel.base_model &&
m.model_type === currentModel.model_type
)
: false;
if (isCurrentModelAvailable) {
return;
}
const result = zMainOrOnnxModel.safeParse(models[0]);
if (!result.success) {
log.error(
@ -68,7 +76,11 @@ export const addModelsLoadedListener = () => {
},
});
startAppListening({
predicate: (state, action) =>
predicate: (
action
): action is TypeGuardFor<
typeof modelsApi.endpoints.getMainModels.matchFulfilled
> =>
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
action.meta.arg.originalArgs.includes('sdxl-refiner'),
effect: async (action, { getState, dispatch }) => {
@ -80,30 +92,29 @@ export const addModelsLoadedListener = () => {
);
const currentModel = getState().sdxl.refinerModel;
const models = mainModelsAdapter.getSelectors().selectAll(action.payload);
const isCurrentModelAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === currentModel?.model_name &&
m?.base_model === currentModel?.base_model &&
m?.model_type === currentModel?.model_type
);
if (isCurrentModelAvailable) {
return;
}
const firstModelId = action.payload.ids[0];
const firstModel = action.payload.entities[firstModelId];
if (!firstModel) {
if (models.length === 0) {
// No models loaded at all
dispatch(refinerModelChanged(null));
dispatch(setShouldUseSDXLRefiner(false));
return;
}
const result = zSDXLRefinerModel.safeParse(firstModel);
const isCurrentModelAvailable = currentModel
? models.some(
(m) =>
m.model_name === currentModel.model_name &&
m.base_model === currentModel.base_model &&
m.model_type === currentModel.model_type
)
: false;
if (isCurrentModelAvailable) {
return;
}
const result = zSDXLRefinerModel.safeParse(models[0]);
if (!result.success) {
log.error(

View File

@ -13,7 +13,7 @@ export const addReceivedOpenAPISchemaListener = () => {
const log = logger('system');
const schemaJSON = action.payload;
log.debug({ schemaJSON }, 'Dereferenced OpenAPI schema');
log.debug({ schemaJSON }, 'Received OpenAPI schema');
const nodeTemplates = parseSchema(schemaJSON);
@ -28,9 +28,12 @@ export const addReceivedOpenAPISchemaListener = () => {
startAppListening({
actionCreator: receivedOpenAPISchema.rejected,
effect: () => {
effect: (action) => {
const log = logger('system');
log.error('Problem dereferencing OpenAPI Schema');
log.error(
{ error: parseify(action.error) },
'Problem retrieving OpenAPI Schema'
);
},
});
};

View File

@ -19,7 +19,7 @@ import {
} from 'services/events/actions';
import { startAppListening } from '../..';
const nodeDenylist = ['dataURL_image'];
const nodeDenylist = ['load_image'];
export const addInvocationCompleteEventListener = () => {
startAppListening({

View File

@ -15,7 +15,7 @@ export const addUserInvokedNodesListener = () => {
const log = logger('session');
const state = getState();
const graph = buildNodesGraph(state);
const graph = buildNodesGraph(state.nodes);
dispatch(nodesGraphBuilt(graph));
log.debug({ graph: parseify(graph) }, 'Nodes graph built');

View File

@ -1,86 +1,7 @@
import {
// CONTROLNET_MODELS,
CONTROLNET_PROCESSORS,
} from 'features/controlNet/store/constants';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { O } from 'ts-toolbelt';
// These are old types from the model management UI
// export type ModelStatus = 'active' | 'cached' | 'not loaded';
// export type Model = {
// status: ModelStatus;
// description: string;
// weights: string;
// config?: string;
// vae?: string;
// width?: number;
// height?: number;
// default?: boolean;
// format?: string;
// };
// export type DiffusersModel = {
// status: ModelStatus;
// description: string;
// repo_id?: string;
// path?: string;
// vae?: {
// repo_id?: string;
// path?: string;
// };
// format?: string;
// default?: boolean;
// };
// export type ModelList = Record<string, Model & DiffusersModel>;
// export type FoundModel = {
// name: string;
// location: string;
// };
// export type InvokeModelConfigProps = {
// name: string | undefined;
// description: string | undefined;
// config: string | undefined;
// weights: string | undefined;
// vae: string | undefined;
// width: number | undefined;
// height: number | undefined;
// default: boolean | undefined;
// format: string | undefined;
// };
// export type InvokeDiffusersModelConfigProps = {
// name: string | undefined;
// description: string | undefined;
// repo_id: string | undefined;
// path: string | undefined;
// default: boolean | undefined;
// format: string | undefined;
// vae: {
// repo_id: string | undefined;
// path: string | undefined;
// };
// };
// export type InvokeModelConversionProps = {
// model_name: string;
// save_location: string;
// custom_location: string | null;
// };
// export type InvokeModelMergingProps = {
// models_to_merge: string[];
// alpha: number;
// interp: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference';
// force: boolean;
// merged_model_name: string;
// model_merge_save_path: string | null;
// };
/**
* A disable-able application feature
*/

View File

@ -6,10 +6,6 @@ import {
useColorMode,
useColorModeValue,
} from '@chakra-ui/react';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd';
import IAIIconButton from 'common/components/IAIIconButton';
import {
IAILoadingImageFallback,
@ -17,6 +13,10 @@ import {
} from 'common/components/IAIImageFallback';
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'features/dnd/types';
import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import {
MouseEvent,
@ -157,11 +157,10 @@ const IAIDndImage = (props: IAIDndImageProps) => {
<IAILoadingImageFallback image={imageDTO} />
)
}
width={imageDTO.width}
height={imageDTO.height}
onError={onError}
draggable={false}
sx={{
w: imageDTO.width,
objectFit: 'contain',
maxW: 'full',
maxH: 'full',
@ -213,13 +212,6 @@ const IAIDndImage = (props: IAIDndImageProps) => {
onClick={onClick}
/>
)}
{!isDropDisabled && (
<IAIDroppable
data={droppableData}
disabled={isDropDisabled}
dropLabel={dropLabel}
/>
)}
{onClickReset && withResetIcon && imageDTO && (
<IAIIconButton
onClick={onClickReset}
@ -244,6 +236,13 @@ const IAIDndImage = (props: IAIDndImageProps) => {
}}
/>
)}
{!isDropDisabled && (
<IAIDroppable
data={droppableData}
disabled={isDropDisabled}
dropLabel={dropLabel}
/>
)}
</Flex>
)}
</ImageContextMenu>

View File

@ -1,22 +1,19 @@
import { Box } from '@chakra-ui/react';
import {
TypesafeDraggableData,
useDraggable,
} from 'app/components/ImageDnd/typesafeDnd';
import { MouseEvent, memo, useRef } from 'react';
import { Box, BoxProps } from '@chakra-ui/react';
import { useDraggableTypesafe } from 'features/dnd/hooks/typesafeHooks';
import { TypesafeDraggableData } from 'features/dnd/types';
import { memo, useRef } from 'react';
import { v4 as uuidv4 } from 'uuid';
type IAIDraggableProps = {
type IAIDraggableProps = BoxProps & {
disabled?: boolean;
data?: TypesafeDraggableData;
onClick?: (event: MouseEvent<HTMLDivElement>) => void;
};
const IAIDraggable = (props: IAIDraggableProps) => {
const { data, disabled, onClick } = props;
const { data, disabled, ...rest } = props;
const dndId = useRef(uuidv4());
const { attributes, listeners, setNodeRef } = useDraggable({
const { attributes, listeners, setNodeRef } = useDraggableTypesafe({
id: dndId.current,
disabled,
data,
@ -24,7 +21,6 @@ const IAIDraggable = (props: IAIDraggableProps) => {
return (
<Box
onClick={onClick}
ref={setNodeRef}
position="absolute"
w="full"
@ -33,6 +29,7 @@ const IAIDraggable = (props: IAIDraggableProps) => {
insetInlineStart={0}
{...attributes}
{...listeners}
{...rest}
/>
);
};

View File

@ -1,9 +1,7 @@
import { Box } from '@chakra-ui/react';
import {
TypesafeDroppableData,
isValidDrop,
useDroppable,
} from 'app/components/ImageDnd/typesafeDnd';
import { useDroppableTypesafe } from 'features/dnd/hooks/typesafeHooks';
import { TypesafeDroppableData } from 'features/dnd/types';
import { isValidDrop } from 'features/dnd/util/isValidDrop';
import { AnimatePresence } from 'framer-motion';
import { ReactNode, memo, useRef } from 'react';
import { v4 as uuidv4 } from 'uuid';
@ -19,7 +17,7 @@ const IAIDroppable = (props: IAIDroppableProps) => {
const { dropLabel, data, disabled } = props;
const dndId = useRef(uuidv4());
const { isOver, setNodeRef, active } = useDroppable({
const { isOver, setNodeRef, active } = useDroppableTypesafe({
id: dndId.current,
disabled,
data,

View File

@ -49,7 +49,7 @@ export const IAILoadingImageFallback = (props: Props) => {
type IAINoImageFallbackProps = {
label?: string;
icon?: As;
icon?: As | null;
boxSize?: StyleProps['boxSize'];
sx?: ChakraProps['sx'];
};
@ -76,7 +76,7 @@ export const IAINoContentFallback = (props: IAINoImageFallbackProps) => {
...props.sx,
}}
>
<Icon as={icon} boxSize={boxSize} opacity={0.7} />
{icon && <Icon as={icon} boxSize={boxSize} opacity={0.7} />}
{props.label && <Text textAlign="center">{props.label}</Text>}
</Flex>
);

View File

@ -1,10 +1,13 @@
import {
Flex,
FormControl,
FormControlProps,
FormHelperText,
FormLabel,
FormLabelProps,
Switch,
SwitchProps,
Text,
Tooltip,
} from '@chakra-ui/react';
import { memo } from 'react';
@ -15,6 +18,7 @@ export interface IAISwitchProps extends SwitchProps {
formControlProps?: FormControlProps;
formLabelProps?: FormLabelProps;
tooltip?: string;
helperText?: string;
}
/**
@ -28,6 +32,7 @@ const IAISwitch = (props: IAISwitchProps) => {
formControlProps,
formLabelProps,
tooltip,
helperText,
...rest
} = props;
return (
@ -35,25 +40,33 @@ const IAISwitch = (props: IAISwitchProps) => {
<FormControl
isDisabled={isDisabled}
width={width}
display="flex"
alignItems="center"
{...formControlProps}
>
{label && (
<FormLabel
my={1}
flexGrow={1}
sx={{
cursor: isDisabled ? 'not-allowed' : 'pointer',
...formLabelProps?.sx,
pe: 4,
}}
{...formLabelProps}
>
{label}
</FormLabel>
)}
<Switch {...rest} />
<Flex sx={{ flexDir: 'column', w: 'full' }}>
<Flex sx={{ alignItems: 'center', w: 'full' }}>
{label && (
<FormLabel
my={1}
flexGrow={1}
sx={{
cursor: isDisabled ? 'not-allowed' : 'pointer',
...formLabelProps?.sx,
pe: 4,
}}
{...formLabelProps}
>
{label}
</FormLabel>
)}
<Switch {...rest} />
</Flex>
{helperText && (
<FormHelperText>
<Text variant="subtext">{helperText}</Text>
</FormHelperText>
)}
</Flex>
</FormControl>
</Tooltip>
);

View File

@ -40,6 +40,44 @@ export const useChakraThemeTokens = () => {
accent850,
accent900,
accent950,
baseAlpha50,
baseAlpha100,
baseAlpha150,
baseAlpha200,
baseAlpha250,
baseAlpha300,
baseAlpha350,
baseAlpha400,
baseAlpha450,
baseAlpha500,
baseAlpha550,
baseAlpha600,
baseAlpha650,
baseAlpha700,
baseAlpha750,
baseAlpha800,
baseAlpha850,
baseAlpha900,
baseAlpha950,
accentAlpha50,
accentAlpha100,
accentAlpha150,
accentAlpha200,
accentAlpha250,
accentAlpha300,
accentAlpha350,
accentAlpha400,
accentAlpha450,
accentAlpha500,
accentAlpha550,
accentAlpha600,
accentAlpha650,
accentAlpha700,
accentAlpha750,
accentAlpha800,
accentAlpha850,
accentAlpha900,
accentAlpha950,
] = useToken('colors', [
'base.50',
'base.100',
@ -79,6 +117,44 @@ export const useChakraThemeTokens = () => {
'accent.850',
'accent.900',
'accent.950',
'baseAlpha.50',
'baseAlpha.100',
'baseAlpha.150',
'baseAlpha.200',
'baseAlpha.250',
'baseAlpha.300',
'baseAlpha.350',
'baseAlpha.400',
'baseAlpha.450',
'baseAlpha.500',
'baseAlpha.550',
'baseAlpha.600',
'baseAlpha.650',
'baseAlpha.700',
'baseAlpha.750',
'baseAlpha.800',
'baseAlpha.850',
'baseAlpha.900',
'baseAlpha.950',
'accentAlpha.50',
'accentAlpha.100',
'accentAlpha.150',
'accentAlpha.200',
'accentAlpha.250',
'accentAlpha.300',
'accentAlpha.350',
'accentAlpha.400',
'accentAlpha.450',
'accentAlpha.500',
'accentAlpha.550',
'accentAlpha.600',
'accentAlpha.650',
'accentAlpha.700',
'accentAlpha.750',
'accentAlpha.800',
'accentAlpha.850',
'accentAlpha.900',
'accentAlpha.950',
]);
return {
@ -120,5 +196,43 @@ export const useChakraThemeTokens = () => {
accent850,
accent900,
accent950,
baseAlpha50,
baseAlpha100,
baseAlpha150,
baseAlpha200,
baseAlpha250,
baseAlpha300,
baseAlpha350,
baseAlpha400,
baseAlpha450,
baseAlpha500,
baseAlpha550,
baseAlpha600,
baseAlpha650,
baseAlpha700,
baseAlpha750,
baseAlpha800,
baseAlpha850,
baseAlpha900,
baseAlpha950,
accentAlpha50,
accentAlpha100,
accentAlpha150,
accentAlpha200,
accentAlpha250,
accentAlpha300,
accentAlpha350,
accentAlpha400,
accentAlpha450,
accentAlpha500,
accentAlpha550,
accentAlpha600,
accentAlpha650,
accentAlpha700,
accentAlpha750,
accentAlpha800,
accentAlpha850,
accentAlpha900,
accentAlpha950,
};
};

View File

@ -1,4 +1,10 @@
/**
* Serialize an object to JSON and back to a new object
*/
export const parseify = (obj: unknown) => JSON.parse(JSON.stringify(obj));
export const parseify = (obj: unknown) => {
try {
return JSON.parse(JSON.stringify(obj));
} catch {
return 'Error parsing object';
}
};

View File

@ -4,7 +4,7 @@ import { skipToken } from '@reduxjs/toolkit/dist/query';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd';
} from 'features/dnd/types';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';

View File

@ -138,7 +138,7 @@ export type RequiredZoeDepthImageProcessorInvocation = O.Required<
/**
* Any ControlNet Processor node, with its parameters flagged as required
*/
export type RequiredControlNetProcessorNode =
export type RequiredControlNetProcessorNode = O.Required<
| RequiredCannyImageProcessorInvocation
| RequiredContentShuffleImageProcessorInvocation
| RequiredHedImageProcessorInvocation
@ -150,7 +150,9 @@ export type RequiredControlNetProcessorNode =
| RequiredNormalbaeImageProcessorInvocation
| RequiredOpenposeImageProcessorInvocation
| RequiredPidiImageProcessorInvocation
| RequiredZoeDepthImageProcessorInvocation;
| RequiredZoeDepthImageProcessorInvocation,
'id'
>;
/**
* Type guard for CannyImageProcessorInvocation

View File

@ -3,6 +3,7 @@ import { RootState } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { some } from 'lodash-es';
import { ImageUsage } from './types';
import { isInvocationNode } from 'features/nodes/types/types';
export const getImageUsage = (state: RootState, image_name: string) => {
const { generation, canvas, nodes, controlNet } = state;
@ -12,11 +13,11 @@ export const getImageUsage = (state: RootState, image_name: string) => {
(obj) => obj.kind === 'image' && obj.imageName === image_name
);
const isNodesImage = nodes.nodes.some((node) => {
const isNodesImage = nodes.nodes.filter(isInvocationNode).some((node) => {
return some(
node.data.inputs,
(input) =>
input.type === 'image' && input.value?.image_name === image_name
input.type === 'ImageField' && input.value?.image_name === image_name
);
});

View File

@ -6,23 +6,18 @@ import {
useSensor,
useSensors,
} from '@dnd-kit/core';
import { snapCenterToCursor } from '@dnd-kit/modifiers';
import { logger } from 'app/logging/logger';
import { dndDropped } from 'app/store/middleware/listenerMiddleware/listeners/imageDropped';
import { useAppDispatch } from 'app/store/storeHooks';
import { parseify } from 'common/util/serialize';
import { AnimatePresence, motion } from 'framer-motion';
import { PropsWithChildren, memo, useCallback, useState } from 'react';
import { useScaledModifer } from '../hooks/useScaledCenteredModifer';
import { DragEndEvent, DragStartEvent, TypesafeDraggableData } from '../types';
import { DndContextTypesafe } from './DndContextTypesafe';
import DragPreview from './DragPreview';
import {
DndContext,
DragEndEvent,
DragStartEvent,
TypesafeDraggableData,
} from './typesafeDnd';
import { logger } from 'app/logging/logger';
type ImageDndContextProps = PropsWithChildren;
const ImageDndContext = (props: ImageDndContextProps) => {
const AppDndContext = (props: PropsWithChildren) => {
const [activeDragData, setActiveDragData] =
useState<TypesafeDraggableData | null>(null);
const log = logger('images');
@ -31,7 +26,10 @@ const ImageDndContext = (props: ImageDndContextProps) => {
const handleDragStart = useCallback(
(event: DragStartEvent) => {
log.trace({ dragData: event.active.data.current }, 'Drag started');
log.trace(
{ dragData: parseify(event.active.data.current) },
'Drag started'
);
const activeData = event.active.data.current;
if (!activeData) {
return;
@ -43,7 +41,10 @@ const ImageDndContext = (props: ImageDndContextProps) => {
const handleDragEnd = useCallback(
(event: DragEndEvent) => {
log.trace({ dragData: event.active.data.current }, 'Drag ended');
log.trace(
{ dragData: parseify(event.active.data.current) },
'Drag ended'
);
const overData = event.over?.data.current;
if (!activeDragData || !overData) {
return;
@ -69,15 +70,29 @@ const ImageDndContext = (props: ImageDndContextProps) => {
const sensors = useSensors(mouseSensor, touchSensor);
const scaledModifier = useScaledModifer();
return (
<DndContext
<DndContextTypesafe
onDragStart={handleDragStart}
onDragEnd={handleDragEnd}
sensors={sensors}
collisionDetection={pointerWithin}
autoScroll={false}
>
{props.children}
<DragOverlay dropAnimation={null} modifiers={[snapCenterToCursor]}>
<DragOverlay
dropAnimation={null}
modifiers={[scaledModifier]}
style={{
width: 'min-content',
height: 'min-content',
cursor: 'none',
userSelect: 'none',
// expand overlay to prevent cursor from going outside it and displaying
padding: '10rem',
}}
>
<AnimatePresence>
{activeDragData && (
<motion.div
@ -98,8 +113,8 @@ const ImageDndContext = (props: ImageDndContextProps) => {
)}
</AnimatePresence>
</DragOverlay>
</DndContext>
</DndContextTypesafe>
);
};
export default memo(ImageDndContext);
export default memo(AppDndContext);

View File

@ -0,0 +1,6 @@
import { DndContext } from '@dnd-kit/core';
import { DndContextTypesafeProps } from '../types';
export function DndContextTypesafe(props: DndContextTypesafeProps) {
return <DndContext {...props} />;
}

View File

@ -1,6 +1,6 @@
import { Box, ChakraProps, Flex, Heading, Image } from '@chakra-ui/react';
import { Box, ChakraProps, Flex, Heading, Image, Text } from '@chakra-ui/react';
import { memo } from 'react';
import { TypesafeDraggableData } from './typesafeDnd';
import { TypesafeDraggableData } from '../types';
type OverlayDragImageProps = {
dragData: TypesafeDraggableData | null;
@ -30,19 +30,38 @@ const DragPreview = (props: OverlayDragImageProps) => {
return null;
}
if (props.dragData.payloadType === 'NODE_FIELD') {
const { field, fieldTemplate } = props.dragData.payload;
return (
<Box
sx={{
position: 'relative',
p: 2,
px: 3,
opacity: 0.7,
bg: 'base.300',
borderRadius: 'base',
boxShadow: 'dark-lg',
whiteSpace: 'nowrap',
fontSize: 'sm',
}}
>
<Text>{field.label || fieldTemplate.title}</Text>
</Box>
);
}
if (props.dragData.payloadType === 'IMAGE_DTO') {
const { thumbnail_url, width, height } = props.dragData.payload.imageDTO;
return (
<Box
sx={{
position: 'relative',
width: '100%',
height: '100%',
width: 'full',
height: 'full',
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
userSelect: 'none',
cursor: 'none',
}}
>
<Image
@ -62,8 +81,6 @@ const DragPreview = (props: OverlayDragImageProps) => {
return (
<Flex
sx={{
cursor: 'none',
userSelect: 'none',
position: 'relative',
alignItems: 'center',
justifyContent: 'center',

View File

@ -0,0 +1,15 @@
import { useDraggable, useDroppable } from '@dnd-kit/core';
import {
UseDraggableTypesafeArguments,
UseDraggableTypesafeReturnValue,
UseDroppableTypesafeArguments,
UseDroppableTypesafeReturnValue,
} from '../types';
export function useDroppableTypesafe(props: UseDroppableTypesafeArguments) {
return useDroppable(props) as UseDroppableTypesafeReturnValue;
}
export function useDraggableTypesafe(props: UseDraggableTypesafeArguments) {
return useDraggable(props) as UseDraggableTypesafeReturnValue;
}

View File

@ -0,0 +1,50 @@
import type { Modifier } from '@dnd-kit/core';
import { getEventCoordinates } from '@dnd-kit/utilities';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useCallback } from 'react';
const selectZoom = createSelector(
[stateSelector, activeTabNameSelector],
({ nodes }, activeTabName) => (activeTabName === 'nodes' ? nodes.zoom : 1)
);
/**
* Applies scaling to the drag transform (if on node editor tab) and centers it on cursor.
*/
export const useScaledModifer = () => {
const zoom = useAppSelector(selectZoom);
const modifier: Modifier = useCallback(
({ activatorEvent, draggingNodeRect, transform }) => {
if (draggingNodeRect && activatorEvent) {
const activatorCoordinates = getEventCoordinates(activatorEvent);
if (!activatorCoordinates) {
return transform;
}
const offsetX = activatorCoordinates.x - draggingNodeRect.left;
const offsetY = activatorCoordinates.y - draggingNodeRect.top;
const x = transform.x + offsetX - draggingNodeRect.width / 2;
const y = transform.y + offsetY - draggingNodeRect.height / 2;
const scaleX = transform.scaleX * zoom;
const scaleY = transform.scaleY * zoom;
return {
x,
y,
scaleX,
scaleY,
};
}
return transform;
},
[zoom]
);
return modifier;
};

View File

@ -3,7 +3,6 @@ import {
Active,
Collision,
DndContextProps,
DndContext as OriginalDndContext,
Over,
Translate,
UseDraggableArguments,
@ -11,6 +10,10 @@ import {
useDraggable as useOriginalDraggable,
useDroppable as useOriginalDroppable,
} from '@dnd-kit/core';
import {
InputFieldTemplate,
InputFieldValue,
} from 'features/nodes/types/types';
import { ImageDTO } from 'services/api/types';
type BaseDropData = {
@ -62,6 +65,10 @@ export type RemoveFromBoardDropData = BaseDropData & {
actionType: 'REMOVE_FROM_BOARD';
};
export type AddFieldToLinearViewDropData = BaseDropData & {
actionType: 'ADD_FIELD_TO_LINEAR';
};
export type TypesafeDroppableData =
| CurrentImageDropData
| InitialImageDropData
@ -71,12 +78,22 @@ export type TypesafeDroppableData =
| AddToBatchDropData
| NodesMultiImageDropData
| AddToBoardDropData
| RemoveFromBoardDropData;
| RemoveFromBoardDropData
| AddFieldToLinearViewDropData;
type BaseDragData = {
id: string;
};
export type NodeFieldDraggableData = BaseDragData & {
payloadType: 'NODE_FIELD';
payload: {
nodeId: string;
field: InputFieldValue;
fieldTemplate: InputFieldTemplate;
};
};
export type ImageDraggableData = BaseDragData & {
payloadType: 'IMAGE_DTO';
payload: { imageDTO: ImageDTO };
@ -87,14 +104,17 @@ export type ImageDTOsDraggableData = BaseDragData & {
payload: { imageDTOs: ImageDTO[] };
};
export type TypesafeDraggableData = ImageDraggableData | ImageDTOsDraggableData;
export type TypesafeDraggableData =
| NodeFieldDraggableData
| ImageDraggableData
| ImageDTOsDraggableData;
interface UseDroppableTypesafeArguments
export interface UseDroppableTypesafeArguments
extends Omit<UseDroppableArguments, 'data'> {
data?: TypesafeDroppableData;
}
type UseDroppableTypesafeReturnValue = Omit<
export type UseDroppableTypesafeReturnValue = Omit<
ReturnType<typeof useOriginalDroppable>,
'active' | 'over'
> & {
@ -102,16 +122,12 @@ type UseDroppableTypesafeReturnValue = Omit<
over: TypesafeOver | null;
};
export function useDroppable(props: UseDroppableTypesafeArguments) {
return useOriginalDroppable(props) as UseDroppableTypesafeReturnValue;
}
interface UseDraggableTypesafeArguments
export interface UseDraggableTypesafeArguments
extends Omit<UseDraggableArguments, 'data'> {
data?: TypesafeDraggableData;
}
type UseDraggableTypesafeReturnValue = Omit<
export type UseDraggableTypesafeReturnValue = Omit<
ReturnType<typeof useOriginalDraggable>,
'active' | 'over'
> & {
@ -119,102 +135,14 @@ type UseDraggableTypesafeReturnValue = Omit<
over: TypesafeOver | null;
};
export function useDraggable(props: UseDraggableTypesafeArguments) {
return useOriginalDraggable(props) as UseDraggableTypesafeReturnValue;
}
interface TypesafeActive extends Omit<Active, 'data'> {
export interface TypesafeActive extends Omit<Active, 'data'> {
data: React.MutableRefObject<TypesafeDraggableData | undefined>;
}
interface TypesafeOver extends Omit<Over, 'data'> {
export interface TypesafeOver extends Omit<Over, 'data'> {
data: React.MutableRefObject<TypesafeDroppableData | undefined>;
}
export const isValidDrop = (
overData: TypesafeDroppableData | undefined,
active: TypesafeActive | null
) => {
if (!overData || !active?.data.current) {
return false;
}
const { actionType } = overData;
const { payloadType } = active.data.current;
if (overData.id === active.data.current.id) {
return false;
}
switch (actionType) {
case 'SET_CURRENT_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_INITIAL_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_CONTROLNET_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_CANVAS_INITIAL_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_NODES_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_MULTI_NODES_IMAGE':
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
case 'ADD_TO_BATCH':
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
case 'ADD_TO_BOARD': {
// If the board is the same, don't allow the drop
// Check the payload types
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
if (!isPayloadValid) {
return false;
}
// Check if the image's board is the board we are dragging onto
if (payloadType === 'IMAGE_DTO') {
const { imageDTO } = active.data.current.payload;
const currentBoard = imageDTO.board_id ?? 'none';
const destinationBoard = overData.context.boardId;
return currentBoard !== destinationBoard;
}
if (payloadType === 'IMAGE_DTOS') {
// TODO (multi-select)
return true;
}
return false;
}
case 'REMOVE_FROM_BOARD': {
// If the board is the same, don't allow the drop
// Check the payload types
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
if (!isPayloadValid) {
return false;
}
// Check if the image's board is the board we are dragging onto
if (payloadType === 'IMAGE_DTO') {
const { imageDTO } = active.data.current.payload;
const currentBoard = imageDTO.board_id;
return currentBoard !== 'none';
}
if (payloadType === 'IMAGE_DTOS') {
// TODO (multi-select)
return true;
}
return false;
}
default:
return false;
}
};
interface DragEvent {
activatorEvent: Event;
active: TypesafeActive;
@ -240,6 +168,3 @@ export interface DndContextTypesafeProps
onDragEnd?(event: DragEndEvent): void;
onDragCancel?(event: DragCancelEvent): void;
}
export function DndContext(props: DndContextTypesafeProps) {
return <OriginalDndContext {...props} />;
}

View File

@ -0,0 +1,87 @@
import { TypesafeActive, TypesafeDroppableData } from '../types';
export const isValidDrop = (
overData: TypesafeDroppableData | undefined,
active: TypesafeActive | null
) => {
if (!overData || !active?.data.current) {
return false;
}
const { actionType } = overData;
const { payloadType } = active.data.current;
if (overData.id === active.data.current.id) {
return false;
}
switch (actionType) {
case 'ADD_FIELD_TO_LINEAR':
return payloadType === 'NODE_FIELD';
case 'SET_CURRENT_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_INITIAL_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_CONTROLNET_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_CANVAS_INITIAL_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_NODES_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_MULTI_NODES_IMAGE':
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
case 'ADD_TO_BATCH':
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
case 'ADD_TO_BOARD': {
// If the board is the same, don't allow the drop
// Check the payload types
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
if (!isPayloadValid) {
return false;
}
// Check if the image's board is the board we are dragging onto
if (payloadType === 'IMAGE_DTO') {
const { imageDTO } = active.data.current.payload;
const currentBoard = imageDTO.board_id ?? 'none';
const destinationBoard = overData.context.boardId;
return currentBoard !== destinationBoard;
}
if (payloadType === 'IMAGE_DTOS') {
// TODO (multi-select)
return true;
}
return false;
}
case 'REMOVE_FROM_BOARD': {
// If the board is the same, don't allow the drop
// Check the payload types
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
if (!isPayloadValid) {
return false;
}
// Check if the image's board is the board we are dragging onto
if (payloadType === 'IMAGE_DTO') {
const { imageDTO } = active.data.current.payload;
const currentBoard = imageDTO.board_id;
return currentBoard !== 'none';
}
if (payloadType === 'IMAGE_DTOS') {
// TODO (multi-select)
return true;
}
return false;
}
default:
return false;
}
};

View File

@ -11,7 +11,6 @@ import {
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { AddToBoardDropData } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
@ -32,6 +31,7 @@ import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { BoardDTO } from 'services/api/types';
import AutoAddIcon from '../AutoAddIcon';
import BoardContextMenu from '../BoardContextMenu';
import { AddToBoardDropData } from 'features/dnd/types';
interface GalleryBoardProps {
board: BoardDTO;

View File

@ -1,7 +1,7 @@
import { As, Badge, Flex } from '@chakra-ui/react';
import { TypesafeDroppableData } from 'app/components/ImageDnd/typesafeDnd';
import IAIDroppable from 'common/components/IAIDroppable';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { TypesafeDroppableData } from 'features/dnd/types';
import { BoardId } from 'features/gallery/store/types';
import { ReactNode } from 'react';
import BoardContextMenu from '../BoardContextMenu';

View File

@ -1,15 +1,15 @@
import { Box, Flex, Image, Text } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { RemoveFromBoardDropData } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import InvokeAILogoImage from 'assets/images/logo.png';
import IAIDroppable from 'common/components/IAIDroppable';
import SelectionOverlay from 'common/components/SelectionOverlay';
import { RemoveFromBoardDropData } from 'features/dnd/types';
import {
boardIdSelected,
autoAddBoardIdChanged,
boardIdSelected,
} from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useMemo, useState } from 'react';
import { useBoardName } from 'services/api/hooks/useBoardName';

View File

@ -1,14 +1,14 @@
import { Box, Flex, Image } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'features/dnd/types';
import { useNextPrevImage } from 'features/gallery/hooks/useNextPrevImage';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { AnimatePresence, motion } from 'framer-motion';

View File

@ -52,11 +52,13 @@ const ImageGalleryContent = () => {
return (
<VStack
layerStyle="first"
sx={{
flexDirection: 'column',
h: 'full',
w: 'full',
borderRadius: 'base',
p: 2,
}}
>
<Box sx={{ w: 'full' }}>

View File

@ -1,9 +1,4 @@
import { Box, Flex } from '@chakra-ui/react';
import {
ImageDTOsDraggableData,
ImageDraggableData,
TypesafeDraggableData,
} from 'app/components/ImageDnd/typesafeDnd';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import IAIFillSkeleton from 'common/components/IAIFillSkeleton';
@ -12,6 +7,11 @@ import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
import { MouseEvent, memo, useCallback, useMemo } from 'react';
import { FaTrash } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import {
ImageDTOsDraggableData,
ImageDraggableData,
TypesafeDraggableData,
} from 'features/dnd/types';
interface HoverableImageProps {
imageName: string;

View File

@ -26,7 +26,7 @@ const overlayScrollbarsConfig: UseOverlayScrollbarsParams = {
options: {
scrollbars: {
visibility: 'auto',
autoHide: 'leave',
autoHide: 'scroll',
autoHideDelay: 1300,
theme: 'os-theme-dark',
},

View File

@ -1,26 +1,40 @@
import { Box, Flex, IconButton, Tooltip } from '@chakra-ui/react';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { useMemo } from 'react';
import { FaCopy } from 'react-icons/fa';
import { useCallback, useMemo } from 'react';
import { FaCopy, FaSave } from 'react-icons/fa';
type Props = {
copyTooltip: string;
label: string;
jsonObject: object;
fileName?: string;
};
const ImageMetadataJSON = (props: Props) => {
const { copyTooltip, jsonObject } = props;
const { label, jsonObject, fileName } = props;
const jsonString = useMemo(
() => JSON.stringify(jsonObject, null, 2),
[jsonObject]
);
const handleCopy = useCallback(() => {
navigator.clipboard.writeText(jsonString);
}, [jsonString]);
const handleSave = useCallback(() => {
const blob = new Blob([jsonString]);
const a = document.createElement('a');
a.href = URL.createObjectURL(blob);
a.download = `${fileName || label}.json`;
document.body.appendChild(a);
a.click();
a.remove();
}, [jsonString, label, fileName]);
return (
<Flex
layerStyle="second"
sx={{
borderRadius: 'base',
bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' },
flexGrow: 1,
w: 'full',
h: 'full',
@ -36,6 +50,7 @@ const ImageMetadataJSON = (props: Props) => {
bottom: 0,
overflow: 'auto',
p: 4,
fontSize: 'sm',
}}
>
<OverlayScrollbarsComponent
@ -44,7 +59,7 @@ const ImageMetadataJSON = (props: Props) => {
options={{
scrollbars: {
visibility: 'auto',
autoHide: 'move',
autoHide: 'scroll',
autoHideDelay: 1300,
theme: 'os-theme-dark',
},
@ -54,12 +69,22 @@ const ImageMetadataJSON = (props: Props) => {
</OverlayScrollbarsComponent>
</Box>
<Flex sx={{ position: 'absolute', top: 0, insetInlineEnd: 0, p: 2 }}>
<Tooltip label={copyTooltip}>
<Tooltip label={`Save ${label} JSON`}>
<IconButton
aria-label={copyTooltip}
aria-label={`Save ${label} JSON`}
icon={<FaSave />}
variant="ghost"
opacity={0.7}
onClick={handleSave}
/>
</Tooltip>
<Tooltip label={`Copy ${label} JSON`}>
<IconButton
aria-label={`Copy ${label} JSON`}
icon={<FaCopy />}
variant="ghost"
onClick={() => navigator.clipboard.writeText(jsonString)}
opacity={0.7}
onClick={handleCopy}
/>
</Tooltip>
</Flex>

View File

@ -10,7 +10,8 @@ import {
Text,
} from '@chakra-ui/react';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { memo, useMemo } from 'react';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { memo } from 'react';
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types';
import { useDebounce } from 'use-debounce';
@ -41,48 +42,15 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
const metadata = currentData?.metadata;
const graph = currentData?.graph;
const tabData = useMemo(() => {
const _tabData: { label: string; data: object; copyTooltip: string }[] = [];
if (metadata) {
_tabData.push({
label: 'Core Metadata',
data: metadata,
copyTooltip: 'Copy Core Metadata JSON',
});
}
if (image) {
_tabData.push({
label: 'Image Details',
data: image,
copyTooltip: 'Copy Image Details JSON',
});
}
if (graph) {
_tabData.push({
label: 'Graph',
data: graph,
copyTooltip: 'Copy Graph JSON',
});
}
return _tabData;
}, [metadata, graph, image]);
return (
<Flex
layerStyle="first"
sx={{
padding: 4,
gap: 1,
flexDirection: 'column',
width: 'full',
height: 'full',
backdropFilter: 'blur(20px)',
bg: 'baseAlpha.200',
_dark: {
bg: 'blackAlpha.600',
},
borderRadius: 'base',
position: 'absolute',
overflow: 'hidden',
@ -103,32 +71,33 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
sx={{ display: 'flex', flexDir: 'column', w: 'full', h: 'full' }}
>
<TabList>
{tabData.map((tab) => (
<Tab
key={tab.label}
sx={{
borderTopRadius: 'base',
}}
>
<Text sx={{ color: 'base.700', _dark: { color: 'base.300' } }}>
{tab.label}
</Text>
</Tab>
))}
<Tab>Core Metadata</Tab>
<Tab>Image Details</Tab>
<Tab>Graph</Tab>
</TabList>
<TabPanels sx={{ w: 'full', h: 'full' }}>
{tabData.map((tab) => (
<TabPanel
key={tab.label}
sx={{ w: 'full', h: 'full', p: 0, pt: 4 }}
>
<ImageMetadataJSON
jsonObject={tab.data}
copyTooltip={tab.copyTooltip}
/>
</TabPanel>
))}
<TabPanels>
<TabPanel>
{metadata ? (
<ImageMetadataJSON jsonObject={metadata} label="Core Metadata" />
) : (
<IAINoContentFallback label="No core metadata found" />
)}
</TabPanel>
<TabPanel>
{image ? (
<ImageMetadataJSON jsonObject={image} label="Image Details" />
) : (
<IAINoContentFallback label="No image details found" />
)}
</TabPanel>
<TabPanel>
{graph ? (
<ImageMetadataJSON jsonObject={graph} label="Graph" />
) : (
<IAINoContentFallback label="No graph found" />
)}
</TabPanel>
</TabPanels>
</Tabs>
</Flex>

View File

@ -9,30 +9,40 @@ import { map } from 'lodash-es';
import { forwardRef, useCallback } from 'react';
import 'reactflow/dist/style.css';
import { AnyInvocationType } from 'services/events/types';
import { useBuildInvocation } from '../hooks/useBuildInvocation';
import { useBuildNodeData } from '../hooks/useBuildNodeData';
import { nodeAdded } from '../store/nodesSlice';
type NodeTemplate = {
label: string;
value: string;
description: string;
tags: string[];
};
const selector = createSelector(
[stateSelector],
({ nodes }) => {
const data: NodeTemplate[] = map(nodes.invocationTemplates, (template) => {
const data: NodeTemplate[] = map(nodes.nodeTemplates, (template) => {
return {
label: template.title,
value: template.type,
description: template.description,
tags: template.tags,
};
});
data.push({
label: 'Progress Image',
value: 'progress_image',
description: 'Displays the progress image in the Node Editor',
value: 'current_image',
description: 'Displays the current image in the Node Editor',
tags: ['progress'],
});
data.push({
label: 'Notes',
value: 'notes',
description: 'Add notes about your workflow',
tags: ['notes'],
});
return { data };
@ -44,7 +54,7 @@ const AddNodeMenu = () => {
const dispatch = useAppDispatch();
const { data } = useAppSelector(selector);
const buildInvocation = useBuildInvocation();
const buildInvocation = useBuildNodeData();
const toaster = useAppToaster();
@ -89,11 +99,12 @@ const AddNodeMenu = () => {
filter={(value, item: NodeTemplate) =>
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim()) ||
item.description.toLowerCase().includes(value.toLowerCase().trim())
item.description.toLowerCase().includes(value.toLowerCase().trim()) ||
item.tags.includes(value.toLowerCase().trim())
}
onChange={handleChange}
sx={{
width: '18rem',
width: '24rem',
}}
/>
</Flex>

View File

@ -0,0 +1,61 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { ConnectionLineComponentProps, getBezierPath } from 'reactflow';
import { FIELDS, colorTokenToCssVar } from '../types/constants';
const selector = createSelector(stateSelector, ({ nodes }) => {
const { shouldAnimateEdges, currentConnectionFieldType, shouldColorEdges } =
nodes;
const stroke =
currentConnectionFieldType && shouldColorEdges
? colorTokenToCssVar(FIELDS[currentConnectionFieldType].color)
: colorTokenToCssVar('base.500');
let className = 'react-flow__custom_connection-path';
if (shouldAnimateEdges) {
className = className.concat(' animated');
}
return {
stroke,
className,
};
});
export const CustomConnectionLine = ({
fromX,
fromY,
fromPosition,
toX,
toY,
toPosition,
}: ConnectionLineComponentProps) => {
const { stroke, className } = useAppSelector(selector);
const pathParams = {
sourceX: fromX,
sourceY: fromY,
sourcePosition: fromPosition,
targetX: toX,
targetY: toY,
targetPosition: toPosition,
};
const [dAttr] = getBezierPath(pathParams);
return (
<g>
<path
fill="none"
stroke={stroke}
strokeWidth={2}
className={className}
d={dAttr}
style={{ opacity: 0.8 }}
/>
</g>
);
};

View File

@ -0,0 +1,183 @@
import { Badge, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { useMemo } from 'react';
import {
BaseEdge,
EdgeLabelRenderer,
EdgeProps,
getBezierPath,
} from 'reactflow';
import { FIELDS, colorTokenToCssVar } from '../types/constants';
import { isInvocationNode } from '../types/types';
const makeEdgeSelector = (
source: string,
sourceHandleId: string | null | undefined,
target: string,
targetHandleId: string | null | undefined,
selected?: boolean
) =>
createSelector(stateSelector, ({ nodes }) => {
const sourceNode = nodes.nodes.find((node) => node.id === source);
const targetNode = nodes.nodes.find((node) => node.id === target);
const isInvocationToInvocationEdge =
isInvocationNode(sourceNode) && isInvocationNode(targetNode);
const isSelected = sourceNode?.selected || targetNode?.selected || selected;
const sourceType = isInvocationToInvocationEdge
? sourceNode?.data?.outputs[sourceHandleId || '']?.type
: undefined;
const stroke =
sourceType && nodes.shouldColorEdges
? colorTokenToCssVar(FIELDS[sourceType].color)
: colorTokenToCssVar('base.500');
return {
isSelected,
shouldAnimate: nodes.shouldAnimateEdges && isSelected,
stroke,
};
});
const CollapsedEdge = ({
sourceX,
sourceY,
targetX,
targetY,
sourcePosition,
targetPosition,
markerEnd,
data,
selected,
source,
target,
sourceHandleId,
targetHandleId,
}: EdgeProps<{ count: number }>) => {
const selector = useMemo(
() =>
makeEdgeSelector(
source,
sourceHandleId,
target,
targetHandleId,
selected
),
[selected, source, sourceHandleId, target, targetHandleId]
);
const { isSelected, shouldAnimate } = useAppSelector(selector);
const [edgePath, labelX, labelY] = getBezierPath({
sourceX,
sourceY,
sourcePosition,
targetX,
targetY,
targetPosition,
});
const { base500 } = useChakraThemeTokens();
return (
<>
<BaseEdge
path={edgePath}
markerEnd={markerEnd}
style={{
strokeWidth: isSelected ? 3 : 2,
stroke: base500,
opacity: isSelected ? 0.8 : 0.5,
animation: shouldAnimate
? 'dashdraw 0.5s linear infinite'
: undefined,
strokeDasharray: shouldAnimate ? 5 : 'none',
}}
/>
{data?.count && data.count > 1 && (
<EdgeLabelRenderer>
<Flex
sx={{
position: 'absolute',
transform: `translate(-50%, -50%) translate(${labelX}px,${labelY}px)`,
}}
className="nodrag nopan"
>
<Badge
variant="solid"
sx={{
bg: 'base.500',
opacity: isSelected ? 0.8 : 0.5,
boxShadow: 'base',
}}
>
{data.count}
</Badge>
</Flex>
</EdgeLabelRenderer>
)}
</>
);
};
const DefaultEdge = ({
sourceX,
sourceY,
targetX,
targetY,
sourcePosition,
targetPosition,
markerEnd,
selected,
source,
target,
sourceHandleId,
targetHandleId,
}: EdgeProps) => {
const selector = useMemo(
() =>
makeEdgeSelector(
source,
sourceHandleId,
target,
targetHandleId,
selected
),
[source, sourceHandleId, target, targetHandleId, selected]
);
const { isSelected, shouldAnimate, stroke } = useAppSelector(selector);
const [edgePath] = getBezierPath({
sourceX,
sourceY,
sourcePosition,
targetX,
targetY,
targetPosition,
});
return (
<BaseEdge
path={edgePath}
markerEnd={markerEnd}
style={{
strokeWidth: isSelected ? 3 : 2,
stroke,
opacity: isSelected ? 0.8 : 0.5,
animation: shouldAnimate ? 'dashdraw 0.5s linear infinite' : undefined,
strokeDasharray: shouldAnimate ? 5 : 'none',
}}
/>
);
};
export const edgeTypes = {
collapsed: CollapsedEdge,
default: DefaultEdge,
};

View File

@ -0,0 +1,9 @@
import InvocationNode from './nodes/InvocationNode';
import CurrentImageNode from './nodes/CurrentImageNode';
import NotesNode from './nodes/NotesNode';
export const nodeTypes = {
invocation: InvocationNode,
current_image: CurrentImageNode,
notes: NotesNode,
};

View File

@ -1,64 +0,0 @@
import { Tooltip } from '@chakra-ui/react';
import { CSSProperties, memo } from 'react';
import { Handle, Position, Connection, HandleType } from 'reactflow';
import { FIELDS, HANDLE_TOOLTIP_OPEN_DELAY } from '../types/constants';
// import { useConnectionEventStyles } from '../hooks/useConnectionEventStyles';
import { InputFieldTemplate, OutputFieldTemplate } from '../types/types';
const handleBaseStyles: CSSProperties = {
position: 'absolute',
width: '1rem',
height: '1rem',
borderWidth: 0,
};
const inputHandleStyles: CSSProperties = {
left: '-1rem',
};
const outputHandleStyles: CSSProperties = {
right: '-0.5rem',
};
// const requiredConnectionStyles: CSSProperties = {
// boxShadow: '0 0 0.5rem 0.5rem var(--invokeai-colors-error-400)',
// };
type FieldHandleProps = {
nodeId: string;
field: InputFieldTemplate | OutputFieldTemplate;
isValidConnection: (connection: Connection) => boolean;
handleType: HandleType;
styles?: CSSProperties;
};
const FieldHandle = (props: FieldHandleProps) => {
const { field, isValidConnection, handleType, styles } = props;
const { name, type } = field;
return (
<Tooltip
label={type}
placement={handleType === 'target' ? 'start' : 'end'}
hasArrow
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
>
<Handle
type={handleType}
id={name}
isValidConnection={isValidConnection}
position={handleType === 'target' ? Position.Left : Position.Right}
style={{
backgroundColor: FIELDS[type].colorCssVar,
...styles,
...handleBaseStyles,
...(handleType === 'target' ? inputHandleStyles : outputHandleStyles),
// ...(inputRequirement === 'always' ? requiredConnectionStyles : {}),
// ...connectionEventStyles,
}}
/>
</Tooltip>
);
};
export default memo(FieldHandle);

View File

@ -1,8 +1,8 @@
import 'reactflow/dist/style.css';
import { Tooltip, Badge, Flex } from '@chakra-ui/react';
import { Badge, Flex, Tooltip } from '@chakra-ui/react';
import { map } from 'lodash-es';
import { FIELDS } from '../types/constants';
import { memo } from 'react';
import 'reactflow/dist/style.css';
import { FIELDS } from '../types/constants';
const FieldTypeLegend = () => {
return (
@ -10,8 +10,14 @@ const FieldTypeLegend = () => {
{map(FIELDS, ({ title, description, color }, key) => (
<Tooltip key={key} label={description}>
<Badge
colorScheme={color}
sx={{ userSelect: 'none' }}
sx={{
userSelect: 'none',
color:
parseInt(color.split('.')[1] ?? '0', 10) < 500
? 'base.800'
: 'base.50',
bg: color,
}}
textAlign="center"
>
{title}

View File

@ -1,4 +1,3 @@
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useCallback } from 'react';
import {
@ -7,35 +6,49 @@ import {
OnConnectEnd,
OnConnectStart,
OnEdgesChange,
OnEdgesDelete,
OnInit,
OnMove,
OnNodesChange,
OnNodesDelete,
OnSelectionChangeFunc,
ProOptions,
ReactFlow,
} from 'reactflow';
import { useIsValidConnection } from '../hooks/useIsValidConnection';
import {
connectionEnded,
connectionMade,
connectionStarted,
edgesChanged,
edgesDeleted,
nodesChanged,
setEditorInstance,
nodesDeleted,
selectedEdgesChanged,
selectedNodesChanged,
zoomChanged,
} from '../store/nodesSlice';
import { InvocationComponent } from './InvocationComponent';
import ProgressImageNode from './ProgressImageNode';
import BottomLeftPanel from './panels/BottomLeftPanel.tsx';
import MinimapPanel from './panels/MinimapPanel';
import TopCenterPanel from './panels/TopCenterPanel';
import TopLeftPanel from './panels/TopLeftPanel';
import TopRightPanel from './panels/TopRightPanel';
import { CustomConnectionLine } from './CustomConnectionLine';
import { edgeTypes } from './CustomEdges';
import { nodeTypes } from './CustomNodes';
import BottomLeftPanel from './editorPanels/BottomLeftPanel';
import MinimapPanel from './editorPanels/MinimapPanel';
import TopCenterPanel from './editorPanels/TopCenterPanel';
import TopLeftPanel from './editorPanels/TopLeftPanel';
import TopRightPanel from './editorPanels/TopRightPanel';
const nodeTypes = {
invocation: InvocationComponent,
progress_image: ProgressImageNode,
};
// TODO: can we support reactflow? if not, we could style the attribution so it matches the app
const proOptions: ProOptions = { hideAttribution: true };
export const Flow = () => {
const dispatch = useAppDispatch();
const nodes = useAppSelector((state: RootState) => state.nodes.nodes);
const edges = useAppSelector((state: RootState) => state.nodes.edges);
const nodes = useAppSelector((state) => state.nodes.nodes);
const edges = useAppSelector((state) => state.nodes.edges);
const shouldSnapToGrid = useAppSelector(
(state) => state.nodes.shouldSnapToGrid
);
const isValidConnection = useIsValidConnection();
const onNodesChange: OnNodesChange = useCallback(
(changes) => {
@ -69,10 +82,36 @@ export const Flow = () => {
dispatch(connectionEnded());
}, [dispatch]);
const onInit: OnInit = useCallback(
(v) => {
dispatch(setEditorInstance(v));
if (v) v.fitView();
const onInit: OnInit = useCallback((v) => {
v.fitView();
}, []);
const onEdgesDelete: OnEdgesDelete = useCallback(
(edges) => {
dispatch(edgesDeleted(edges));
},
[dispatch]
);
const onNodesDelete: OnNodesDelete = useCallback(
(nodes) => {
dispatch(nodesDeleted(nodes));
},
[dispatch]
);
const handleSelectionChange: OnSelectionChangeFunc = useCallback(
({ nodes, edges }) => {
dispatch(selectedNodesChanged(nodes ? nodes.map((n) => n.id) : []));
dispatch(selectedEdgesChanged(edges ? edges.map((e) => e.id) : []));
},
[dispatch]
);
const handleMove: OnMove = useCallback(
(e, viewport) => {
const { zoom } = viewport;
dispatch(zoomChanged(zoom));
},
[dispatch]
);
@ -80,24 +119,33 @@ export const Flow = () => {
return (
<ReactFlow
nodeTypes={nodeTypes}
edgeTypes={edgeTypes}
nodes={nodes}
edges={edges}
onNodesChange={onNodesChange}
onEdgesChange={onEdgesChange}
onEdgesDelete={onEdgesDelete}
onNodesDelete={onNodesDelete}
onConnectStart={onConnectStart}
onConnect={onConnect}
onConnectEnd={onConnectEnd}
onMove={handleMove}
connectionLineComponent={CustomConnectionLine}
onSelectionChange={handleSelectionChange}
onInit={onInit}
defaultEdgeOptions={{
style: { strokeWidth: 2 },
}}
isValidConnection={isValidConnection}
minZoom={0.2}
snapToGrid={shouldSnapToGrid}
snapGrid={[25, 25]}
connectionRadius={30}
proOptions={proOptions}
>
<TopLeftPanel />
<TopCenterPanel />
<TopRightPanel />
<BottomLeftPanel />
<Background />
<MinimapPanel />
<Background />
</ReactFlow>
);
};

View File

@ -1,55 +0,0 @@
import { Flex, Heading, Icon, Tooltip } from '@chakra-ui/react';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/hooks/useBuildInvocation';
import { memo } from 'react';
import { FaInfoCircle } from 'react-icons/fa';
interface IAINodeHeaderProps {
nodeId?: string;
title?: string;
description?: string;
}
const IAINodeHeader = (props: IAINodeHeaderProps) => {
const { nodeId, title, description } = props;
return (
<Flex
className={DRAG_HANDLE_CLASSNAME}
sx={{
borderTopRadius: 'md',
alignItems: 'center',
justifyContent: 'space-between',
px: 2,
py: 1,
bg: 'base.100',
_dark: { bg: 'base.900' },
}}
>
<Tooltip label={nodeId}>
<Heading
size="xs"
sx={{
fontWeight: 600,
color: 'base.900',
_dark: { color: 'base.200' },
}}
>
{title}
</Heading>
</Tooltip>
<Tooltip label={description} placement="top" hasArrow shouldWrapChildren>
<Icon
sx={{
h: 'min-content',
color: 'base.700',
_dark: {
color: 'base.300',
},
}}
as={FaInfoCircle}
/>
</Tooltip>
</Flex>
);
};
export default memo(IAINodeHeader);

View File

@ -1,149 +0,0 @@
import {
Box,
Divider,
Flex,
FormControl,
FormLabel,
HStack,
Tooltip,
} from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import {
InputFieldTemplate,
InputFieldValue,
InvocationTemplate,
} from 'features/nodes/types/types';
import { map } from 'lodash-es';
import { ReactNode, memo, useCallback } from 'react';
import FieldHandle from '../FieldHandle';
import InputFieldComponent from '../InputFieldComponent';
interface IAINodeInputProps {
nodeId: string;
input: InputFieldValue;
template?: InputFieldTemplate | undefined;
connected: boolean;
}
function IAINodeInput(props: IAINodeInputProps) {
const { nodeId, input, template, connected } = props;
const isValidConnection = useIsValidConnection();
return (
<Box
className="nopan"
position="relative"
borderColor={
!template
? 'error.400'
: !connected &&
['always', 'connectionOnly'].includes(
String(template?.inputRequirement)
) &&
input.value === undefined
? 'warning.400'
: undefined
}
>
<FormControl isDisabled={!template ? true : connected} pl={2}>
{!template ? (
<HStack justifyContent="space-between" alignItems="center">
<FormLabel>Unknown input: {input.name}</FormLabel>
</HStack>
) : (
<>
<HStack justifyContent="space-between" alignItems="center">
<HStack>
<Tooltip
label={template?.description}
placement="top"
hasArrow
shouldWrapChildren
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
>
<FormLabel>{template?.title}</FormLabel>
</Tooltip>
</HStack>
<InputFieldComponent
nodeId={nodeId}
field={input}
template={template}
/>
</HStack>
{!['never', 'directOnly'].includes(
template?.inputRequirement ?? ''
) && (
<FieldHandle
nodeId={nodeId}
field={template}
isValidConnection={isValidConnection}
handleType="target"
/>
)}
</>
)}
</FormControl>
</Box>
);
}
interface IAINodeInputsProps {
nodeId: string;
template: InvocationTemplate;
inputs: Record<string, InputFieldValue>;
}
const IAINodeInputs = (props: IAINodeInputsProps) => {
const { nodeId, template, inputs } = props;
const edges = useAppSelector((state: RootState) => state.nodes.edges);
const renderIAINodeInputs = useCallback(() => {
const IAINodeInputsToRender: ReactNode[] = [];
const inputSockets = map(inputs);
inputSockets.forEach((inputSocket, index) => {
const inputTemplate = template.inputs[inputSocket.name];
const isConnected = Boolean(
edges.filter((connectedInput) => {
return (
connectedInput.target === nodeId &&
connectedInput.targetHandle === inputSocket.name
);
}).length
);
if (index < inputSockets.length) {
IAINodeInputsToRender.push(
<Divider key={`${inputSocket.id}.divider`} />
);
}
IAINodeInputsToRender.push(
<IAINodeInput
key={inputSocket.id}
nodeId={nodeId}
input={inputSocket}
template={inputTemplate}
connected={isConnected}
/>
);
});
return (
<Flex className="nopan" flexDir="column" gap={2} p={2}>
{IAINodeInputsToRender}
</Flex>
);
}, [edges, inputs, nodeId, template.inputs]);
return renderIAINodeInputs();
};
export default memo(IAINodeInputs);

View File

@ -1,97 +0,0 @@
import {
InvocationTemplate,
OutputFieldTemplate,
OutputFieldValue,
} from 'features/nodes/types/types';
import { memo, ReactNode, useCallback } from 'react';
import { map } from 'lodash-es';
import { useAppSelector } from 'app/store/storeHooks';
import { RootState } from 'app/store/store';
import { Box, Flex, FormControl, FormLabel, HStack } from '@chakra-ui/react';
import FieldHandle from '../FieldHandle';
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
interface IAINodeOutputProps {
nodeId: string;
output: OutputFieldValue;
template?: OutputFieldTemplate | undefined;
connected: boolean;
}
function IAINodeOutput(props: IAINodeOutputProps) {
const { nodeId, output, template, connected } = props;
const isValidConnection = useIsValidConnection();
return (
<Box position="relative">
<FormControl isDisabled={!template ? true : connected} paddingRight={3}>
{!template ? (
<HStack justifyContent="space-between" alignItems="center">
<FormLabel color="error.400">
Unknown Output: {output.name}
</FormLabel>
</HStack>
) : (
<>
<FormLabel textAlign="end" padding={1}>
{template?.title}
</FormLabel>
<FieldHandle
key={output.id}
nodeId={nodeId}
field={template}
isValidConnection={isValidConnection}
handleType="source"
/>
</>
)}
</FormControl>
</Box>
);
}
interface IAINodeOutputsProps {
nodeId: string;
template: InvocationTemplate;
outputs: Record<string, OutputFieldValue>;
}
const IAINodeOutputs = (props: IAINodeOutputsProps) => {
const { nodeId, template, outputs } = props;
const edges = useAppSelector((state: RootState) => state.nodes.edges);
const renderIAINodeOutputs = useCallback(() => {
const IAINodeOutputsToRender: ReactNode[] = [];
const outputSockets = map(outputs);
outputSockets.forEach((outputSocket) => {
const outputTemplate = template.outputs[outputSocket.name];
const isConnected = Boolean(
edges.filter((connectedInput) => {
return (
connectedInput.source === nodeId &&
connectedInput.sourceHandle === outputSocket.name
);
}).length
);
IAINodeOutputsToRender.push(
<IAINodeOutput
key={outputSocket.id}
nodeId={nodeId}
output={outputSocket}
template={outputTemplate}
connected={isConnected}
/>
);
});
return <Flex flexDir="column">{IAINodeOutputsToRender}</Flex>;
}, [edges, nodeId, outputs, template.outputs]);
return renderIAINodeOutputs();
};
export default memo(IAINodeOutputs);

View File

@ -1,252 +0,0 @@
import { Box } from '@chakra-ui/react';
import { memo } from 'react';
import { InputFieldTemplate, InputFieldValue } from '../types/types';
import ArrayInputFieldComponent from './fields/ArrayInputFieldComponent';
import BooleanInputFieldComponent from './fields/BooleanInputFieldComponent';
import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
import ControlNetModelInputFieldComponent from './fields/ControlNetModelInputFieldComponent';
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
import LoRAModelInputFieldComponent from './fields/LoRAModelInputFieldComponent';
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
import StringInputFieldComponent from './fields/StringInputFieldComponent';
import UnetInputFieldComponent from './fields/UnetInputFieldComponent';
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent';
import RefinerModelInputFieldComponent from './fields/RefinerModelInputFieldComponent';
type InputFieldComponentProps = {
nodeId: string;
field: InputFieldValue;
template: InputFieldTemplate;
};
// build an individual input element based on the schema
const InputFieldComponent = (props: InputFieldComponentProps) => {
const { nodeId, field, template } = props;
const { type } = field;
if (type === 'string' && template.type === 'string') {
return (
<StringInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'boolean' && template.type === 'boolean') {
return (
<BooleanInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (
(type === 'integer' && template.type === 'integer') ||
(type === 'float' && template.type === 'float')
) {
return (
<NumberInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'enum' && template.type === 'enum') {
return (
<EnumInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'image' && template.type === 'image') {
return (
<ImageInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'latents' && template.type === 'latents') {
return (
<LatentsInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'conditioning' && template.type === 'conditioning') {
return (
<ConditioningInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'unet' && template.type === 'unet') {
return (
<UnetInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'clip' && template.type === 'clip') {
return (
<ClipInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'vae' && template.type === 'vae') {
return (
<VaeInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'control' && template.type === 'control') {
return (
<ControlInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'model' && template.type === 'model') {
return (
<ModelInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'refiner_model' && template.type === 'refiner_model') {
return (
<RefinerModelInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'vae_model' && template.type === 'vae_model') {
return (
<VaeModelInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'lora_model' && template.type === 'lora_model') {
return (
<LoRAModelInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'controlnet_model' && template.type === 'controlnet_model') {
return (
<ControlNetModelInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'array' && template.type === 'array') {
return (
<ArrayInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'item' && template.type === 'item') {
return (
<ItemInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'color' && template.type === 'color') {
return (
<ColorInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'item' && template.type === 'item') {
return (
<ItemInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'image_collection' && template.type === 'image_collection') {
return (
<ImageCollectionInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
return <Box p={2}>Unknown field type: {type}</Box>;
};
export default memo(InputFieldComponent);

View File

@ -0,0 +1,57 @@
import { ChevronUpIcon } from '@chakra-ui/icons';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { nodeIsOpenChanged } from 'features/nodes/store/nodesSlice';
import { NodeData } from 'features/nodes/types/types';
import { memo, useCallback } from 'react';
import { NodeProps, useUpdateNodeInternals } from 'reactflow';
interface Props {
nodeProps: NodeProps<NodeData>;
}
const NodeCollapseButton = (props: Props) => {
const { id: nodeId, isOpen } = props.nodeProps.data;
const dispatch = useAppDispatch();
const updateNodeInternals = useUpdateNodeInternals();
const handleClick = useCallback(() => {
dispatch(nodeIsOpenChanged({ nodeId, isOpen: !isOpen }));
updateNodeInternals(nodeId);
}, [dispatch, isOpen, nodeId, updateNodeInternals]);
return (
<IAIIconButton
className="nopan"
onClick={handleClick}
aria-label="Minimize"
sx={{
minW: 8,
w: 8,
h: 8,
color: 'base.500',
_dark: {
color: 'base.500',
},
_hover: {
color: 'base.700',
_dark: {
color: 'base.300',
},
},
}}
variant="link"
icon={
<ChevronUpIcon
sx={{
transform: isOpen ? 'rotate(0deg)' : 'rotate(180deg)',
transitionProperty: 'common',
transitionDuration: 'normal',
}}
/>
}
/>
);
};
export default memo(NodeCollapseButton);

View File

@ -0,0 +1,74 @@
import { useColorModeValue } from '@chakra-ui/react';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import {
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { map } from 'lodash-es';
import { CSSProperties, memo, useMemo } from 'react';
import { Handle, NodeProps, Position } from 'reactflow';
interface Props {
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
}
const NodeCollapsedHandles = (props: Props) => {
const { data } = props.nodeProps;
const { base400, base600 } = useChakraThemeTokens();
const backgroundColor = useColorModeValue(base400, base600);
const dummyHandleStyles: CSSProperties = useMemo(
() => ({
borderWidth: 0,
borderRadius: '3px',
width: '1rem',
height: '1rem',
backgroundColor,
zIndex: -1,
}),
[backgroundColor]
);
return (
<>
<Handle
type="target"
id={`${data.id}-collapsed-target`}
isConnectable={false}
position={Position.Left}
style={{ ...dummyHandleStyles, left: '-0.5rem' }}
/>
{map(data.inputs, (input) => (
<Handle
key={`${data.id}-${input.name}-collapsed-input-handle`}
type="target"
id={input.name}
isValidConnection={() => false}
position={Position.Left}
style={{ visibility: 'hidden' }}
/>
))}
<Handle
type="source"
id={`${data.id}-collapsed-source`}
isValidConnection={() => false}
isConnectable={false}
position={Position.Right}
style={{ ...dummyHandleStyles, right: '-0.5rem' }}
/>
{map(data.outputs, (output) => (
<Handle
key={`${data.id}-${output.name}-collapsed-output-handle`}
type="source"
id={output.name}
isValidConnection={() => false}
position={Position.Right}
style={{ visibility: 'hidden' }}
/>
))}
</>
);
};
export default memo(NodeCollapsedHandles);

View File

@ -0,0 +1,77 @@
import {
Checkbox,
Flex,
FormControl,
FormLabel,
Spacer,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import {
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { some } from 'lodash-es';
import { ChangeEvent, memo, useCallback, useMemo } from 'react';
import { NodeProps } from 'reactflow';
type Props = {
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
};
const NodeFooter = (props: Props) => {
const { nodeProps, nodeTemplate } = props;
const dispatch = useAppDispatch();
const hasImageOutput = useMemo(
() =>
some(nodeTemplate?.outputs, (output) =>
['ImageField', 'ImageCollection'].includes(output.type)
),
[nodeTemplate?.outputs]
);
const handleChangeIsIntermediate = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
fieldBooleanValueChanged({
nodeId: nodeProps.data.id,
fieldName: 'is_intermediate',
value: !e.target.checked,
})
);
},
[dispatch, nodeProps.data.id]
);
return (
<Flex
className={DRAG_HANDLE_CLASSNAME}
layerStyle="nodeFooter"
sx={{
w: 'full',
borderBottomRadius: 'base',
px: 2,
py: 0,
h: 6,
}}
>
<Spacer />
{hasImageOutput && (
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Save Output</FormLabel>
<Checkbox
className="nopan"
size="sm"
onChange={handleChangeIsIntermediate}
isChecked={!nodeProps.data.inputs['is_intermediate']?.value}
/>
</FormControl>
)}
</Flex>
);
};
export default memo(NodeFooter);

View File

@ -0,0 +1,113 @@
import {
Flex,
FormControl,
FormLabel,
Icon,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
Text,
Tooltip,
useDisclosure,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAITextarea from 'common/components/IAITextarea';
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import {
InvocationNodeData,
InvocationTemplate,
} from 'features/nodes/types/types';
import { ChangeEvent, memo, useCallback } from 'react';
import { FaInfoCircle } from 'react-icons/fa';
import { NodeProps } from 'reactflow';
interface Props {
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
}
const NodeNotesEdit = (props: Props) => {
const { nodeProps, nodeTemplate } = props;
const { data } = nodeProps;
const { isOpen, onOpen, onClose } = useDisclosure();
const dispatch = useAppDispatch();
const handleNotesChanged = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(nodeNotesChanged({ nodeId: data.id, notes: e.target.value }));
},
[data.id, dispatch]
);
return (
<>
<Tooltip
label={
nodeTemplate ? (
<TooltipContent nodeProps={nodeProps} nodeTemplate={nodeTemplate} />
) : undefined
}
placement="top"
shouldWrapChildren
>
<Flex
className={DRAG_HANDLE_CLASSNAME}
onClick={onOpen}
sx={{
alignItems: 'center',
justifyContent: 'center',
w: 8,
h: 8,
cursor: 'pointer',
}}
>
<Icon
as={FaInfoCircle}
sx={{ boxSize: 4, w: 8, color: 'base.400' }}
/>
</Flex>
</Tooltip>
<Modal isOpen={isOpen} onClose={onClose} isCentered>
<ModalOverlay />
<ModalContent>
<ModalHeader>
{data.label || nodeTemplate?.title || 'Unknown Node'}
</ModalHeader>
<ModalCloseButton />
<ModalBody>
<FormControl>
<FormLabel>Notes</FormLabel>
<IAITextarea
value={data.notes}
onChange={handleNotesChanged}
rows={10}
/>
</FormControl>
</ModalBody>
<ModalFooter />
</ModalContent>
</Modal>
</>
);
};
export default memo(NodeNotesEdit);
type TooltipContentProps = Props;
const TooltipContent = (props: TooltipContentProps) => {
return (
<Flex sx={{ flexDir: 'column' }}>
<Text sx={{ fontWeight: 600 }}>{props.nodeTemplate?.title}</Text>
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
{props.nodeTemplate?.description}
</Text>
{props.nodeProps.data.notes && <Text>{props.nodeProps.data.notes}</Text>}
</Flex>
);
};

View File

@ -2,7 +2,10 @@ import { NODE_MIN_WIDTH } from 'features/nodes/types/constants';
import { memo } from 'react';
import { NodeResizeControl, NodeResizerProps } from 'reactflow';
const IAINodeResizer = (props: NodeResizerProps) => {
// this causes https://github.com/invoke-ai/InvokeAI/issues/4140
// not using it for now
const NodeResizer = (props: NodeResizerProps) => {
const { ...rest } = props;
return (
<NodeResizeControl
@ -21,4 +24,4 @@ const IAINodeResizer = (props: NodeResizerProps) => {
);
};
export default memo(IAINodeResizer);
export default memo(NodeResizer);

View File

@ -0,0 +1,69 @@
import { Flex } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIPopover from 'common/components/IAIPopover';
import IAISwitch from 'common/components/IAISwitch';
import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
import { InvocationNodeData } from 'features/nodes/types/types';
import { ChangeEvent, memo, useCallback } from 'react';
import { FaBars } from 'react-icons/fa';
interface Props {
data: InvocationNodeData;
}
const NodeSettings = (props: Props) => {
const { data } = props;
const dispatch = useAppDispatch();
const handleChangeIsIntermediate = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
fieldBooleanValueChanged({
nodeId: data.id,
fieldName: 'is_intermediate',
value: e.target.checked,
})
);
},
[data.id, dispatch]
);
return (
<IAIPopover
isLazy={false}
triggerComponent={
<IAIIconButton
className="nopan"
aria-label="Node Settings"
variant="link"
sx={{
minW: 8,
color: 'base.500',
_dark: {
color: 'base.500',
},
_hover: {
color: 'base.700',
_dark: {
color: 'base.300',
},
},
}}
icon={<FaBars />}
/>
}
>
<Flex sx={{ flexDir: 'column', gap: 4, w: 64 }}>
<IAISwitch
label="Intermediate"
isChecked={Boolean(data.inputs['is_intermediate']?.value)}
onChange={handleChangeIsIntermediate}
helperText="The outputs of intermediate nodes are considered temporary objects. Intermediate images are not added to the gallery."
/>
</Flex>
</IAIPopover>
);
};
export default memo(NodeSettings);

View File

@ -0,0 +1,185 @@
import {
Badge,
CircularProgress,
Flex,
Icon,
Image,
Text,
Tooltip,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import {
InvocationNodeData,
NodeExecutionState,
NodeStatus,
} from 'features/nodes/types/types';
import { memo, useMemo } from 'react';
import { FaCheck, FaEllipsisH, FaExclamation } from 'react-icons/fa';
import { NodeProps } from 'reactflow';
type Props = {
nodeProps: NodeProps<InvocationNodeData>;
};
const iconBoxSize = 3;
const circleStyles = {
circle: {
transitionProperty: 'none',
transitionDuration: '0s',
},
'.chakra-progress__track': { stroke: 'transparent' },
};
const NodeStatusIndicator = (props: Props) => {
const nodeId = props.nodeProps.data.id;
const selectNodeExecutionState = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => nodes.nodeExecutionStates[nodeId]
),
[nodeId]
);
const nodeExecutionState = useAppSelector(selectNodeExecutionState);
if (!nodeExecutionState) {
return null;
}
return (
<Tooltip
label={<TooltipLabel nodeExecutionState={nodeExecutionState} />}
placement="top"
>
<Flex
className={DRAG_HANDLE_CLASSNAME}
sx={{
w: 5,
h: 'full',
alignItems: 'center',
justifyContent: 'flex-end',
}}
>
<StatusIcon nodeExecutionState={nodeExecutionState} />
</Flex>
</Tooltip>
);
};
export default memo(NodeStatusIndicator);
type TooltipLabelProps = {
nodeExecutionState: NodeExecutionState;
};
const TooltipLabel = ({ nodeExecutionState }: TooltipLabelProps) => {
const { status, progress, progressImage } = nodeExecutionState;
if (status === NodeStatus.PENDING) {
return <Text>Pending</Text>;
}
if (status === NodeStatus.IN_PROGRESS) {
if (progressImage) {
return (
<Flex sx={{ pos: 'relative', pt: 1.5, pb: 0.5 }}>
<Image
src={progressImage.dataURL}
sx={{ w: 32, h: 32, borderRadius: 'base', objectFit: 'contain' }}
/>
{progress !== null && (
<Badge
variant="solid"
sx={{ pos: 'absolute', top: 2.5, insetInlineEnd: 1 }}
>
{Math.round(progress * 100)}%
</Badge>
)}
</Flex>
);
}
if (progress !== null) {
return <Text>In Progress ({Math.round(progress * 100)}%)</Text>;
}
return <Text>In Progress</Text>;
}
if (status === NodeStatus.COMPLETED) {
return <Text>Completed</Text>;
}
if (status === NodeStatus.FAILED) {
return <Text>nodeExecutionState.error</Text>;
}
return null;
};
type StatusIconProps = {
nodeExecutionState: NodeExecutionState;
};
const StatusIcon = (props: StatusIconProps) => {
const { progress, status } = props.nodeExecutionState;
if (status === NodeStatus.PENDING) {
return (
<Icon
as={FaEllipsisH}
sx={{
boxSize: iconBoxSize,
color: 'base.600',
_dark: { color: 'base.300' },
}}
/>
);
}
if (status === NodeStatus.IN_PROGRESS) {
return progress === null ? (
<CircularProgress
isIndeterminate
size="14px"
color="base.500"
thickness={14}
sx={circleStyles}
/>
) : (
<CircularProgress
value={Math.round(progress * 100)}
size="14px"
color="base.500"
thickness={14}
sx={circleStyles}
/>
);
}
if (status === NodeStatus.COMPLETED) {
return (
<Icon
as={FaCheck}
sx={{
boxSize: iconBoxSize,
color: 'ok.600',
_dark: { color: 'ok.300' },
}}
/>
);
}
if (status === NodeStatus.FAILED) {
return (
<Icon
as={FaExclamation}
sx={{
boxSize: iconBoxSize,
color: 'error.600',
_dark: { color: 'error.300' },
}}
/>
);
}
return null;
};

View File

@ -0,0 +1,123 @@
import {
Box,
Editable,
EditableInput,
EditablePreview,
Flex,
useEditableControls,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { nodeLabelChanged } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { NodeData } from 'features/nodes/types/types';
import { MouseEvent, memo, useCallback, useEffect, useState } from 'react';
type Props = {
nodeData: NodeData;
title: string;
};
const NodeTitle = (props: Props) => {
const { title } = props;
const { id: nodeId, label } = props.nodeData;
const dispatch = useAppDispatch();
const [localTitle, setLocalTitle] = useState(label || title);
const handleSubmit = useCallback(
async (newTitle: string) => {
dispatch(nodeLabelChanged({ nodeId, label: newTitle }));
setLocalTitle(newTitle || title);
},
[nodeId, dispatch, title]
);
const handleChange = useCallback((newTitle: string) => {
setLocalTitle(newTitle);
}, []);
useEffect(() => {
// Another component may change the title; sync local title with global state
setLocalTitle(label || title);
}, [label, title]);
return (
<Flex
className="nopan"
sx={{
overflow: 'hidden',
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
cursor: 'text',
}}
>
<Editable
as={Flex}
value={localTitle}
onChange={handleChange}
onSubmit={handleSubmit}
sx={{
alignItems: 'center',
position: 'relative',
w: 'full',
h: 'full',
}}
>
<EditablePreview
fontSize="sm"
sx={{
p: 0,
w: 'full',
}}
noOfLines={1}
/>
<EditableInput
fontSize="sm"
sx={{
p: 0,
_focusVisible: {
p: 0,
boxShadow: 'none',
},
}}
/>
<EditableControls />
</Editable>
</Flex>
);
};
export default memo(NodeTitle);
function EditableControls() {
const { isEditing, getEditButtonProps } = useEditableControls();
const handleDoubleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
const { onClick } = getEditButtonProps();
if (!onClick) {
return;
}
onClick(e);
},
[getEditButtonProps]
);
if (isEditing) {
return null;
}
return (
<Box
className={DRAG_HANDLE_CLASSNAME}
onDoubleClick={handleDoubleClick}
sx={{
position: 'absolute',
w: 'full',
h: 'full',
top: 0,
cursor: 'grab',
}}
/>
);
}

View File

@ -0,0 +1,96 @@
import {
Box,
ChakraProps,
useColorModeValue,
useToken,
} from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { nodeClicked } from 'features/nodes/store/nodesSlice';
import { MouseEvent, PropsWithChildren, useCallback, useMemo } from 'react';
import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from '../../types/constants';
import { NodeData } from 'features/nodes/types/types';
import { NodeProps } from 'reactflow';
const useNodeSelect = (nodeId: string) => {
const dispatch = useAppDispatch();
const selectNode = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
dispatch(nodeClicked({ nodeId, ctrlOrMeta: e.ctrlKey || e.metaKey }));
},
[dispatch, nodeId]
);
return selectNode;
};
type NodeWrapperProps = PropsWithChildren & {
nodeProps: NodeProps<NodeData>;
width?: NonNullable<ChakraProps['sx']>['w'];
};
const NodeWrapper = (props: NodeWrapperProps) => {
const { width, children, nodeProps } = props;
const { data, selected } = nodeProps;
const nodeId = data.id;
const [
nodeSelectedOutlineLight,
nodeSelectedOutlineDark,
shadowsXl,
shadowsBase,
] = useToken('shadows', [
'nodeSelectedOutline.light',
'nodeSelectedOutline.dark',
'shadows.xl',
'shadows.base',
]);
const selectNode = useNodeSelect(nodeId);
const shadow = useColorModeValue(
nodeSelectedOutlineLight,
nodeSelectedOutlineDark
);
const shift = useAppSelector((state) => state.hotkeys.shift);
const opacity = useAppSelector((state) => state.nodes.nodeOpacity);
const className = useMemo(
() => (shift ? DRAG_HANDLE_CLASSNAME : 'nopan'),
[shift]
);
return (
<Box
onClickCapture={selectNode}
className={className}
sx={{
h: 'full',
position: 'relative',
borderRadius: 'base',
w: width ?? NODE_WIDTH,
transitionProperty: 'common',
transitionDuration: '0.1s',
shadow: selected ? shadow : undefined,
opacity,
}}
>
<Box
sx={{
position: 'absolute',
top: 0,
insetInlineEnd: 0,
bottom: 0,
insetInlineStart: 0,
borderRadius: 'base',
pointerEvents: 'none',
shadow: `${shadowsXl}, ${shadowsBase}, ${shadowsBase}`,
zIndex: -1,
}}
/>
{children}
</Box>
);
};
export default NodeWrapper;

View File

@ -1,74 +0,0 @@
import { Flex, Icon } from '@chakra-ui/react';
import { FaExclamationCircle } from 'react-icons/fa';
import { NodeProps } from 'reactflow';
import { InvocationValue } from '../types/types';
import { useAppSelector } from 'app/store/storeHooks';
import { memo, useMemo } from 'react';
import { makeTemplateSelector } from '../store/util/makeTemplateSelector';
import IAINodeHeader from './IAINode/IAINodeHeader';
import IAINodeInputs from './IAINode/IAINodeInputs';
import IAINodeOutputs from './IAINode/IAINodeOutputs';
import IAINodeResizer from './IAINode/IAINodeResizer';
import NodeWrapper from './NodeWrapper';
export const InvocationComponent = memo((props: NodeProps<InvocationValue>) => {
const { id: nodeId, data, selected } = props;
const { type, inputs, outputs } = data;
const templateSelector = useMemo(() => makeTemplateSelector(type), [type]);
const template = useAppSelector(templateSelector);
if (!template) {
return (
<NodeWrapper selected={selected}>
<Flex
className="nopan"
sx={{
alignItems: 'center',
justifyContent: 'center',
cursor: 'auto',
}}
>
<Icon
as={FaExclamationCircle}
sx={{
boxSize: 32,
color: 'base.600',
_dark: { color: 'base.400' },
}}
></Icon>
<IAINodeResizer />
</Flex>
</NodeWrapper>
);
}
return (
<NodeWrapper selected={selected}>
<IAINodeHeader
nodeId={nodeId}
title={template.title}
description={template.description}
/>
<Flex
className={'nopan'}
sx={{
cursor: 'auto',
flexDirection: 'column',
borderBottomRadius: 'md',
py: 2,
bg: 'base.150',
_dark: { bg: 'base.800' },
}}
>
<IAINodeOutputs nodeId={nodeId} outputs={outputs} template={template} />
<IAINodeInputs nodeId={nodeId} inputs={inputs} template={template} />
</Flex>
<IAINodeResizer />
</NodeWrapper>
);
});
InvocationComponent.displayName = 'InvocationComponent';

View File

@ -1,25 +1,45 @@
import { Box } from '@chakra-ui/react';
import { ReactFlowProvider } from 'reactflow';
import ResizeHandle from 'features/ui/components/tabs/ResizeHandle';
import { memo, useState } from 'react';
import { Panel, PanelGroup } from 'react-resizable-panels';
import 'reactflow/dist/style.css';
import { memo } from 'react';
import { Flow } from './Flow';
import NodeEditorPanelGroup from './panel/NodeEditorPanelGroup';
const NodeEditor = () => {
const [isPanelCollapsed, setIsPanelCollapsed] = useState(false);
return (
<Box
layerStyle={'first'}
sx={{
position: 'relative',
width: 'full',
height: 'full',
borderRadius: 'base',
}}
<PanelGroup
id="node-editor"
autoSaveId="node-editor"
direction="horizontal"
style={{ height: '100%', width: '100%' }}
>
<ReactFlowProvider>
<Flow />
</ReactFlowProvider>
</Box>
<Panel
id="node-editor-panel-group"
collapsible
onCollapse={setIsPanelCollapsed}
minSize={25}
>
<NodeEditorPanelGroup />
</Panel>
<ResizeHandle
collapsedDirection={isPanelCollapsed ? 'left' : undefined}
/>
<Panel id="node-editor-content">
<Box
layerStyle={'first'}
sx={{
position: 'relative',
width: 'full',
height: 'full',
borderRadius: 'base',
}}
>
<Flow />
</Box>
</Panel>
</PanelGroup>
);
};

View File

@ -0,0 +1,139 @@
import {
Divider,
Flex,
Heading,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalHeader,
ModalOverlay,
useDisclosure,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch';
import { ChangeEvent, useCallback } from 'react';
import { FaCog } from 'react-icons/fa';
import {
shouldAnimateEdgesChanged,
shouldColorEdgesChanged,
shouldSnapToGridChanged,
shouldValidateGraphChanged,
} from '../store/nodesSlice';
const selector = createSelector(stateSelector, ({ nodes }) => {
const {
shouldAnimateEdges,
shouldValidateGraph,
shouldSnapToGrid,
shouldColorEdges,
} = nodes;
return {
shouldAnimateEdges,
shouldValidateGraph,
shouldSnapToGrid,
shouldColorEdges,
};
});
const NodeEditorSettings = () => {
const { isOpen, onOpen, onClose } = useDisclosure();
const dispatch = useAppDispatch();
const {
shouldAnimateEdges,
shouldValidateGraph,
shouldSnapToGrid,
shouldColorEdges,
} = useAppSelector(selector);
const handleChangeShouldValidate = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(shouldValidateGraphChanged(e.target.checked));
},
[dispatch]
);
const handleChangeShouldAnimate = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(shouldAnimateEdgesChanged(e.target.checked));
},
[dispatch]
);
const handleChangeShouldSnap = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(shouldSnapToGridChanged(e.target.checked));
},
[dispatch]
);
const handleChangeShouldColor = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(shouldColorEdgesChanged(e.target.checked));
},
[dispatch]
);
return (
<>
<IAIIconButton
aria-label="Node Editor Settings"
icon={<FaCog />}
onClick={onOpen}
/>
<Modal isOpen={isOpen} onClose={onClose} size="2xl" isCentered>
<ModalOverlay />
<ModalContent>
<ModalHeader>Node Editor Settings</ModalHeader>
<ModalCloseButton />
<ModalBody>
<Flex
sx={{
flexDirection: 'column',
gap: 4,
py: 4,
}}
>
<Heading size="sm">General</Heading>
<IAISwitch
onChange={handleChangeShouldAnimate}
isChecked={shouldAnimateEdges}
label="Animated Edges"
helperText="Animate selected edges and edges connected to selected nodes"
/>
<Divider />
<IAISwitch
isChecked={shouldSnapToGrid}
onChange={handleChangeShouldSnap}
label="Snap to Grid"
helperText="Snap nodes to grid when moved"
/>
<Divider />
<IAISwitch
isChecked={shouldColorEdges}
onChange={handleChangeShouldColor}
label="Color-Code Edges"
helperText="Color-code edges according to their connected fields"
/>
<Heading size="sm" pt={4}>
Advanced
</Heading>
<IAISwitch
isChecked={shouldValidateGraph}
onChange={handleChangeShouldValidate}
label="Validate Connections and Graph"
helperText="Prevent invalid connections from being made, and invalid graphs from being invoked"
/>
</Flex>
</ModalBody>
</ModalContent>
</Modal>
</>
);
};
export default NodeEditorSettings;

View File

@ -1,34 +1,26 @@
import { Box } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { memo } from 'react';
import ImageMetadataJSON from 'features/gallery/components/ImageMetadataViewer/ImageMetadataJSON';
import { omit } from 'lodash-es';
import { useMemo } from 'react';
import { useDebounce } from 'use-debounce';
import { buildNodesGraph } from '../util/graphBuilders/buildNodesGraph';
const NodeGraphOverlay = () => {
const state = useAppSelector((state: RootState) => state);
const graph = buildNodesGraph(state);
return (
<Box
as="pre"
sx={{
fontFamily: 'monospace',
position: 'absolute',
top: 2,
right: 2,
opacity: 0.7,
p: 2,
maxHeight: 500,
maxWidth: 500,
overflowY: 'scroll',
borderRadius: 'base',
bg: 'base.200',
_dark: { bg: 'base.800' },
}}
>
{JSON.stringify(graph, null, 2)}
</Box>
const useNodesGraph = () => {
const nodes = useAppSelector((state: RootState) => state.nodes);
const [debouncedNodes] = useDebounce(nodes, 300);
const graph = useMemo(
() => omit(buildNodesGraph(debouncedNodes), 'id'),
[debouncedNodes]
);
return graph;
};
export default memo(NodeGraphOverlay);
const NodeGraph = () => {
const graph = useNodesGraph();
return <ImageMetadataJSON jsonObject={graph} label="Graph" />;
};
export default NodeGraph;

View File

@ -0,0 +1,42 @@
import {
Box,
Slider,
SliderFilledTrack,
SliderThumb,
SliderTrack,
} from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useCallback } from 'react';
import { nodeOpacityChanged } from '../store/nodesSlice';
export default function NodeOpacitySlider() {
const dispatch = useAppDispatch();
const nodeOpacity = useAppSelector((state) => state.nodes.nodeOpacity);
const handleChange = useCallback(
(v: number) => {
dispatch(nodeOpacityChanged(v));
},
[dispatch]
);
return (
<Box>
<Slider
aria-label="Node Opacity"
value={nodeOpacity}
min={0.5}
max={1}
step={0.01}
onChange={handleChange}
orientation="vertical"
defaultValue={30}
>
<SliderTrack>
<SliderFilledTrack />
</SliderTrack>
<SliderThumb />
</Slider>
</Box>
);
}

View File

@ -1,36 +0,0 @@
import { Box, useToken } from '@chakra-ui/react';
import { useAppSelector } from 'app/store/storeHooks';
import { PropsWithChildren } from 'react';
import { DRAG_HANDLE_CLASSNAME } from '../hooks/useBuildInvocation';
import { NODE_MIN_WIDTH } from '../types/constants';
type NodeWrapperProps = PropsWithChildren & {
selected: boolean;
};
const NodeWrapper = (props: NodeWrapperProps) => {
const [nodeSelectedOutline, nodeShadow] = useToken('shadows', [
'nodeSelectedOutline',
'dark-lg',
]);
const shift = useAppSelector((state) => state.hotkeys.shift);
return (
<Box
className={shift ? DRAG_HANDLE_CLASSNAME : 'nopan'}
sx={{
position: 'relative',
borderRadius: 'md',
minWidth: NODE_MIN_WIDTH,
shadow: props.selected
? `${nodeSelectedOutline}, ${nodeShadow}`
: `${nodeShadow}`,
}}
>
{props.children}
</Box>
);
};
export default NodeWrapper;

View File

@ -1,73 +0,0 @@
import { Flex, Image } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { memo } from 'react';
import { useDispatch, useSelector } from 'react-redux';
import { NodeProps, OnResize } from 'reactflow';
import { setProgressNodeSize } from '../store/nodesSlice';
import IAINodeHeader from './IAINode/IAINodeHeader';
import IAINodeResizer from './IAINode/IAINodeResizer';
import NodeWrapper from './NodeWrapper';
const ProgressImageNode = (props: NodeProps) => {
const progressImage = useSelector(
(state: RootState) => state.system.progressImage
);
const progressNodeSize = useSelector(
(state: RootState) => state.nodes.progressNodeSize
);
const dispatch = useDispatch();
const { selected } = props;
const handleResize: OnResize = (_, newSize) => {
dispatch(setProgressNodeSize(newSize));
};
return (
<NodeWrapper selected={selected}>
<IAINodeHeader
title="Progress Image"
description="Displays the progress image in the Node Editor"
/>
<Flex
sx={{
flexDirection: 'column',
flexShrink: 0,
borderBottomRadius: 'md',
bg: 'base.200',
_dark: { bg: 'base.800' },
width: progressNodeSize.width - 2,
height: progressNodeSize.height - 2,
minW: 250,
minH: 250,
overflow: 'hidden',
}}
>
{progressImage ? (
<Image
src={progressImage.dataURL}
sx={{
w: 'full',
h: 'full',
objectFit: 'contain',
}}
/>
) : (
<Flex
sx={{
minW: 250,
minH: 250,
width: progressNodeSize.width - 2,
height: progressNodeSize.height - 2,
}}
>
<IAINoContentFallback />
</Flex>
)}
</Flex>
<IAINodeResizer onResize={handleResize} />
</NodeWrapper>
);
};
export default memo(ProgressImageNode);

View File

@ -2,18 +2,16 @@ import { ButtonGroup, Tooltip } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { memo, useCallback } from 'react';
import {
FaCode,
FaExpand,
FaMinus,
FaPlus,
FaInfo,
FaMapMarkerAlt,
} from 'react-icons/fa';
import { useReactFlow } from 'reactflow';
import { useTranslation } from 'react-i18next';
import {
shouldShowGraphOverlayChanged,
FaExpand,
FaInfo,
FaMapMarkerAlt,
FaMinus,
FaPlus,
} from 'react-icons/fa';
import { useReactFlow } from 'reactflow';
import {
shouldShowFieldTypeLegendChanged,
shouldShowMinimapPanelChanged,
} from '../store/nodesSlice';
@ -22,9 +20,6 @@ const ViewportControls = () => {
const { t } = useTranslation();
const { zoomIn, zoomOut, fitView } = useReactFlow();
const dispatch = useAppDispatch();
const shouldShowGraphOverlay = useAppSelector(
(state) => state.nodes.shouldShowGraphOverlay
);
const shouldShowFieldTypeLegend = useAppSelector(
(state) => state.nodes.shouldShowFieldTypeLegend
);
@ -44,10 +39,6 @@ const ViewportControls = () => {
fitView();
}, [fitView]);
const handleClickedToggleGraphOverlay = useCallback(() => {
dispatch(shouldShowGraphOverlayChanged(!shouldShowGraphOverlay));
}, [shouldShowGraphOverlay, dispatch]);
const handleClickedToggleFieldTypeLegend = useCallback(() => {
dispatch(shouldShowFieldTypeLegendChanged(!shouldShowFieldTypeLegend));
}, [shouldShowFieldTypeLegend, dispatch]);
@ -79,20 +70,6 @@ const ViewportControls = () => {
icon={<FaExpand />}
/>
</Tooltip>
<Tooltip
label={
shouldShowGraphOverlay
? t('nodes.hideGraphNodes')
: t('nodes.showGraphNodes')
}
>
<IAIIconButton
aria-label="Toggle nodes graph overlay"
isChecked={shouldShowGraphOverlay}
onClick={handleClickedToggleGraphOverlay}
icon={<FaCode />}
/>
</Tooltip>
<Tooltip
label={
shouldShowFieldTypeLegend

View File

@ -1,10 +1,15 @@
import { memo } from 'react';
import { Panel } from 'reactflow';
import ViewportControls from '../ViewportControls';
import NodeOpacitySlider from '../NodeOpacitySlider';
import { Flex } from '@chakra-ui/react';
const BottomLeftPanel = () => (
<Panel position="bottom-left">
<ViewportControls />
<Flex sx={{ gap: 2 }}>
<ViewportControls />
<NodeOpacitySlider />
</Flex>
</Panel>
);

View File

@ -20,7 +20,7 @@ const MinimapPanel = () => {
const nodeColor = useColorModeValue(
'var(--invokeai-colors-accent-300)',
'var(--invokeai-colors-accent-700)'
'var(--invokeai-colors-accent-600)'
);
const maskColor = useColorModeValue(
@ -32,10 +32,9 @@ const MinimapPanel = () => {
<>
{shouldShowMinimapPanel && (
<MiniMap
nodeStrokeWidth={3}
pannable
zoomable
nodeBorderRadius={30}
nodeBorderRadius={15}
style={miniMapStyle}
nodeColor={nodeColor}
maskColor={maskColor}

View File

@ -2,11 +2,10 @@ import { HStack } from '@chakra-ui/react';
import CancelButton from 'features/parameters/components/ProcessButtons/CancelButton';
import { memo } from 'react';
import { Panel } from 'reactflow';
import NodeEditorSettings from '../NodeEditorSettings';
import ClearGraphButton from '../ui/ClearGraphButton';
import LoadGraphButton from '../ui/LoadGraphButton';
import NodeInvokeButton from '../ui/NodeInvokeButton';
import ReloadSchemaButton from '../ui/ReloadSchemaButton';
import SaveGraphButton from '../ui/SaveGraphButton';
const TopCenterPanel = () => {
return (
@ -15,9 +14,8 @@ const TopCenterPanel = () => {
<NodeInvokeButton />
<CancelButton />
<ReloadSchemaButton />
<SaveGraphButton />
<LoadGraphButton />
<ClearGraphButton />
<NodeEditorSettings />
</HStack>
</Panel>
);

View File

@ -1,22 +1,16 @@
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { memo } from 'react';
import { Panel } from 'reactflow';
import FieldTypeLegend from '../FieldTypeLegend';
import NodeGraphOverlay from '../NodeGraphOverlay';
const TopRightPanel = () => {
const shouldShowGraphOverlay = useAppSelector(
(state: RootState) => state.nodes.shouldShowGraphOverlay
);
const shouldShowFieldTypeLegend = useAppSelector(
(state: RootState) => state.nodes.shouldShowFieldTypeLegend
(state) => state.nodes.shouldShowFieldTypeLegend
);
return (
<Panel position="top-right">
{shouldShowFieldTypeLegend && <FieldTypeLegend />}
{shouldShowGraphOverlay && <NodeGraphOverlay />}
</Panel>
);
};

View File

@ -1,15 +0,0 @@
import {
ArrayInputFieldTemplate,
ArrayInputFieldValue,
} from 'features/nodes/types/types';
import { memo } from 'react';
import { FaList } from 'react-icons/fa';
import { FieldComponentProps } from './types';
const ArrayInputFieldComponent = (
_props: FieldComponentProps<ArrayInputFieldValue, ArrayInputFieldTemplate>
) => {
return <FaList />;
};
export default memo(ArrayInputFieldComponent);

View File

@ -1,37 +0,0 @@
import { Select } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
EnumInputFieldTemplate,
EnumInputFieldValue,
} from 'features/nodes/types/types';
import { ChangeEvent, memo } from 'react';
import { FieldComponentProps } from './types';
const EnumInputFieldComponent = (
props: FieldComponentProps<EnumInputFieldValue, EnumInputFieldTemplate>
) => {
const { nodeId, field, template } = props;
const dispatch = useAppDispatch();
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: e.target.value,
})
);
};
return (
<Select onChange={handleValueChanged} value={field.value}>
{template.options.map((option) => (
<option key={option}>{option}</option>
))}
</Select>
);
};
export default memo(EnumInputFieldComponent);

View File

@ -0,0 +1,47 @@
import { MenuItem, MenuList } from '@chakra-ui/react';
import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu';
import {
InputFieldTemplate,
InputFieldValue,
} from 'features/nodes/types/types';
import { MouseEvent, useCallback } from 'react';
import { menuListMotionProps } from 'theme/components/menu';
type Props = {
nodeId: string;
field: InputFieldValue;
fieldTemplate: InputFieldTemplate;
children: ContextMenuProps<HTMLDivElement>['children'];
};
const FieldContextMenu = (props: Props) => {
const skipEvent = useCallback((e: MouseEvent<HTMLDivElement>) => {
e.preventDefault();
}, []);
return (
<ContextMenu<HTMLDivElement>
menuProps={{
size: 'sm',
isLazy: true,
}}
menuButtonProps={{
bg: 'transparent',
_hover: { bg: 'transparent' },
}}
renderMenu={() => (
<MenuList
sx={{ visibility: 'visible !important' }}
motionProps={menuListMotionProps}
onContextMenu={skipEvent}
>
<MenuItem>Test</MenuItem>
</MenuList>
)}
>
{props.children}
</ContextMenu>
);
};
export default FieldContextMenu;

View File

@ -0,0 +1,122 @@
import { Tooltip } from '@chakra-ui/react';
import { CSSProperties, memo, useMemo } from 'react';
import { Handle, HandleType, NodeProps, Position } from 'reactflow';
import {
FIELDS,
HANDLE_TOOLTIP_OPEN_DELAY,
colorTokenToCssVar,
} from '../../types/constants';
import {
InputFieldTemplate,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
OutputFieldTemplate,
OutputFieldValue,
} from '../../types/types';
export const handleBaseStyles: CSSProperties = {
position: 'absolute',
width: '1rem',
height: '1rem',
borderWidth: 0,
zIndex: 1,
};
export const inputHandleStyles: CSSProperties = {
left: '-1rem',
};
export const outputHandleStyles: CSSProperties = {
right: '-0.5rem',
};
type FieldHandleProps = {
nodeProps: NodeProps<InvocationNodeData>;
nodeTemplate: InvocationTemplate;
field: InputFieldValue | OutputFieldValue;
fieldTemplate: InputFieldTemplate | OutputFieldTemplate;
handleType: HandleType;
isConnectionInProgress: boolean;
isConnectionStartField: boolean;
connectionError: string | null;
};
const FieldHandle = (props: FieldHandleProps) => {
const {
fieldTemplate,
handleType,
isConnectionInProgress,
isConnectionStartField,
connectionError,
} = props;
const { name, type } = fieldTemplate;
const { color, title } = FIELDS[type];
const styles: CSSProperties = useMemo(() => {
const s: CSSProperties = {
backgroundColor: colorTokenToCssVar(color),
position: 'absolute',
width: '1rem',
height: '1rem',
borderWidth: 0,
zIndex: 1,
};
if (handleType === 'target') {
s.insetInlineStart = '-1rem';
} else {
s.insetInlineEnd = '-1rem';
}
if (isConnectionInProgress && !isConnectionStartField && connectionError) {
s.filter = 'opacity(0.4) grayscale(0.7)';
}
if (isConnectionInProgress && connectionError) {
if (isConnectionStartField) {
s.cursor = 'grab';
} else {
s.cursor = 'not-allowed';
}
} else {
s.cursor = 'crosshair';
}
return s;
}, [
color,
connectionError,
handleType,
isConnectionInProgress,
isConnectionStartField,
]);
const tooltip = useMemo(() => {
if (isConnectionInProgress && isConnectionStartField) {
return title;
}
if (isConnectionInProgress && connectionError) {
return connectionError ?? title;
}
return title;
}, [connectionError, isConnectionInProgress, isConnectionStartField, title]);
return (
<Tooltip
label={tooltip}
placement={handleType === 'target' ? 'start' : 'end'}
hasArrow
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
>
<Handle
type={handleType}
id={name}
position={handleType === 'target' ? Position.Left : Position.Right}
style={styles}
/>
</Tooltip>
);
};
export default memo(FieldHandle);

Some files were not shown because too many files have changed in this diff Show More