Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into feat/batch-graphs

Commit 314891a125
@@ -5,7 +5,7 @@ from PIL import Image
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
from pydantic import BaseModel
from pydantic import BaseModel, Field

from invokeai.app.invocations.metadata import ImageMetadata
from invokeai.app.models.image import ImageCategory, ResourceOrigin
@@ -19,6 +19,7 @@ from ..dependencies import ApiDependencies

images_router = APIRouter(prefix="/v1/images", tags=["images"])


# images are immutable; set a high max-age
IMAGE_MAX_AGE = 31536000

@@ -286,3 +287,41 @@ async def delete_images_from_list(
return DeleteImagesFromListResult(deleted_images=deleted_images)
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to delete images")


class ImagesUpdatedFromListResult(BaseModel):
updated_image_names: list[str] = Field(description="The image names that were updated")


@images_router.post("/star", operation_id="star_images_in_list", response_model=ImagesUpdatedFromListResult)
async def star_images_in_list(
image_names: list[str] = Body(description="The list of names of images to star", embed=True),
) -> ImagesUpdatedFromListResult:
try:
updated_image_names: list[str] = []
for image_name in image_names:
try:
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=True))
updated_image_names.append(image_name)
except:
pass
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to star images")


@images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=ImagesUpdatedFromListResult)
async def unstar_images_in_list(
image_names: list[str] = Body(description="The list of names of images to unstar", embed=True),
) -> ImagesUpdatedFromListResult:
try:
updated_image_names: list[str] = []
for image_name in image_names:
try:
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=False))
updated_image_names.append(image_name)
except:
pass
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to unstar images")
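The new star/unstar endpoints expect a JSON body in which the `image_names` list is wrapped in an object, because of `Body(..., embed=True)`. As a rough sketch of calling them from Python (not part of the diff; the server address and the `/api` mount prefix are assumptions):

```python
# Illustrative client call; the base URL, the /api prefix and the image name are assumptions.
import requests

BASE_URL = "http://127.0.0.1:9090/api"  # assumed local InvokeAI server

def star_images(image_names: list[str]) -> list[str]:
    # embed=True means the list must be sent as {"image_names": [...]}, not as a bare JSON array.
    resp = requests.post(f"{BASE_URL}/v1/images/star", json={"image_names": image_names})
    resp.raise_for_status()
    return resp.json()["updated_image_names"]

print(star_images(["example-image.png"]))
```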
@@ -38,7 +38,7 @@ import mimetypes
from .api.dependencies import ApiDependencies
from .api.routers import sessions, models, images, boards, board_images, app_info
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase


import torch
@@ -134,6 +134,11 @@ def custom_openapi():
# This could break in some cases, figure out a better way to do it
output_type_titles[schema_key] = output_schema["title"]

# Add Node Editor UI helper schemas
ui_config_schemas = schema([UIConfigBase, _InputField, _OutputField], ref_prefix="#/components/schemas/")
for schema_key, output_schema in ui_config_schemas["definitions"].items():
openapi_schema["components"]["schemas"][schema_key] = output_schema

# Add a reference to the output type to additionalProperties of the invoker schema
for invoker in all_invocations:
invoker_name = invoker.__name__
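The `custom_openapi()` addition relies on pydantic v1's `schema()` helper, which builds a single combined JSON schema for several models under a `definitions` key; the loop then copies those definitions into the OpenAPI components. A standalone sketch of that behaviour (the two models below are invented for the example):

```python
# Standalone illustration of pydantic v1's schema() helper; these models are invented.
from pydantic import BaseModel
from pydantic.schema import schema

class _ExampleInput(BaseModel):
    ui_hidden: bool

class _ExampleOutput(BaseModel):
    ui_hidden: bool

combined = schema([_ExampleInput, _ExampleOutput], ref_prefix="#/components/schemas/")
# combined["definitions"] maps each model's title to its JSON schema; custom_openapi()
# copies such entries into openapi_schema["components"]["schemas"].
print(sorted(combined["definitions"].keys()))  # ['_ExampleInput', '_ExampleOutput']
```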
@@ -3,15 +3,366 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from enum import Enum
from inspect import signature
from typing import TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args, get_type_hints
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Callable,
ClassVar,
Mapping,
Optional,
Type,
TypeVar,
Union,
get_args,
get_type_hints,
)

from pydantic import BaseConfig, BaseModel, Field
from pydantic import BaseModel, Field
from pydantic.fields import Undefined
from pydantic.typing import NoArgAnyCallable

if TYPE_CHECKING:
from ..services.invocation_services import InvocationServices

class FieldDescriptions:
denoising_start = "When to start denoising, expressed as a percentage of total steps"
denoising_end = "When to stop denoising, expressed as a percentage of total steps"
cfg_scale = "Classifier-Free Guidance scale"
scheduler = "Scheduler to use during inference"
positive_cond = "Positive conditioning tensor"
negative_cond = "Negative conditioning tensor"
noise = "Noise tensor"
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
unet = "UNet (scheduler, LoRAs)"
vae = "VAE"
cond = "Conditioning tensor"
controlnet_model = "ControlNet model to load"
vae_model = "VAE model to load"
lora_model = "LoRA model to load"
main_model = "Main model (UNet, VAE, CLIP) to load"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Model (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
lora_weight = "The weight at which the LoRA is applied to each model"
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
raw_prompt = "Raw prompt text (no parsing)"
sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor"
skipped_layers = "Number of layers to skip in text encoder"
seed = "Seed for random number generation"
steps = "Number of steps to run"
width = "Width of output (px)"
height = "Height of output (px)"
control = "ControlNet(s) to apply"
denoised_latents = "Denoised latents tensor"
latents = "Latents tensor"
strength = "Strength of denoising (proportional to steps)"
core_metadata = "Optional core metadata to be written to image"
interp_mode = "Interpolation mode"
torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
fp32 = "Whether or not to use full float32 precision"
precision = "Precision to use"
tiled = "Processing using overlapping tiles (reduce memory consumption)"
detect_res = "Pixel resolution for detection"
image_res = "Pixel resolution for output image"
safe_mode = "Whether or not to use safe mode"
scribble_mode = "Whether or not to use scribble mode"
scale_factor = "The factor by which to scale"
num_1 = "The first number"
num_2 = "The second number"
mask = "The mask to use for the operation"

class Input(str, Enum):
"""
The type of input a field accepts.
- `Input.Direct`: The field must have its value provided directly, when the invocation and field \
are instantiated.
- `Input.Connection`: The field must have its value provided by a connection.
- `Input.Any`: The field may have its value provided either directly or by a connection.
"""

Connection = "connection"
Direct = "direct"
Any = "any"


class UIType(str, Enum):
"""
Type hints for the UI.
If a field should be provided a data type that does not exactly match the python type of the field, \
use this to provide the type that should be used instead. See the node development docs for detail \
on adding a new field type, which involves client-side changes.
"""

# region Primitives
Integer = "integer"
Float = "float"
Boolean = "boolean"
String = "string"
Array = "array"
Image = "ImageField"
Latents = "LatentsField"
Conditioning = "ConditioningField"
Control = "ControlField"
Color = "ColorField"
ImageCollection = "ImageCollection"
ConditioningCollection = "ConditioningCollection"
ColorCollection = "ColorCollection"
LatentsCollection = "LatentsCollection"
IntegerCollection = "IntegerCollection"
FloatCollection = "FloatCollection"
StringCollection = "StringCollection"
BooleanCollection = "BooleanCollection"
# endregion

# region Models
MainModel = "MainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
ONNXModel = "ONNXModelField"
VaeModel = "VaeModelField"
LoRAModel = "LoRAModelField"
ControlNetModel = "ControlNetModelField"
UNet = "UNetField"
Vae = "VaeField"
CLIP = "ClipField"
# endregion

# region Iterate/Collect
Collection = "Collection"
CollectionItem = "CollectionItem"
# endregion

# region Misc
FilePath = "FilePath"
Enum = "enum"
# endregion


class UIComponent(str, Enum):
"""
The type of UI component to use for a field, used to override the default components, which are \
inferred from the field type.
"""

None_ = "none"
Textarea = "textarea"
Slider = "slider"

class _InputField(BaseModel):
"""
*DO NOT USE*
This helper class is used to tell the client about our custom field attributes via OpenAPI
schema generation, and Typescript type generation from that schema. It serves no functional
purpose in the backend.
"""

input: Input
ui_hidden: bool
ui_type: Optional[UIType]
ui_component: Optional[UIComponent]


class _OutputField(BaseModel):
"""
*DO NOT USE*
This helper class is used to tell the client about our custom field attributes via OpenAPI
schema generation, and Typescript type generation from that schema. It serves no functional
purpose in the backend.
"""

ui_hidden: bool
ui_type: Optional[UIType]

def InputField(
*args: Any,
default: Any = Undefined,
default_factory: Optional[NoArgAnyCallable] = None,
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
exclude: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
include: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
const: Optional[bool] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
multiple_of: Optional[float] = None,
allow_inf_nan: Optional[bool] = None,
max_digits: Optional[int] = None,
decimal_places: Optional[int] = None,
min_items: Optional[int] = None,
max_items: Optional[int] = None,
unique_items: Optional[bool] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
allow_mutation: bool = True,
regex: Optional[str] = None,
discriminator: Optional[str] = None,
repr: bool = True,
input: Input = Input.Any,
ui_type: Optional[UIType] = None,
ui_component: Optional[UIComponent] = None,
ui_hidden: bool = False,
**kwargs: Any,
) -> Any:
"""
Creates an input field for an invocation.

This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
that adds a few extra parameters to support graph execution and the node editor UI.

:param Input input: [Input.Any] The kind of input this field requires. \
`Input.Direct` means a value must be provided on instantiation. \
`Input.Connection` means the value must be provided by a connection. \
`Input.Any` means either will do.

:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
In some situations, the field's type is not enough to infer the correct UI type. \
For example, model selection fields should render a dropdown UI component to select a model. \
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.

:param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \
The UI will always render a suitable component, but sometimes you want something different than the default. \
For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \
For this case, you could provide `UIComponent.Textarea`.

:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
"""
return Field(
*args,
default=default,
default_factory=default_factory,
alias=alias,
title=title,
description=description,
exclude=exclude,
include=include,
const=const,
gt=gt,
ge=ge,
lt=lt,
le=le,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
min_items=min_items,
max_items=max_items,
unique_items=unique_items,
min_length=min_length,
max_length=max_length,
allow_mutation=allow_mutation,
regex=regex,
discriminator=discriminator,
repr=repr,
input=input,
ui_type=ui_type,
ui_component=ui_component,
ui_hidden=ui_hidden,
**kwargs,
)

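To make the docstring above concrete, here is a small sketch of how the extra `InputField` parameters are meant to be used in a node's field declarations (not part of the diff; the model and field names are invented):

```python
# Hypothetical field declarations illustrating the extra InputField parameters.
class _ExampleNodeFields(BaseModel):
    # Must be supplied by a connection from another node's output.
    clip: Optional[Any] = InputField(default=None, description="CLIP submodel", input=Input.Connection)
    # May be typed directly; rendered as a multi-line textarea in the node editor.
    prompt: str = InputField(default="", description="Prompt text", ui_component=UIComponent.Textarea)
    # Provided either directly or by a connection (Input.Any, the default).
    steps: int = InputField(default=30, ge=1, description="Number of steps")
```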
def OutputField(
*args: Any,
default: Any = Undefined,
default_factory: Optional[NoArgAnyCallable] = None,
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
exclude: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
include: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
const: Optional[bool] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
multiple_of: Optional[float] = None,
allow_inf_nan: Optional[bool] = None,
max_digits: Optional[int] = None,
decimal_places: Optional[int] = None,
min_items: Optional[int] = None,
max_items: Optional[int] = None,
unique_items: Optional[bool] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
allow_mutation: bool = True,
regex: Optional[str] = None,
discriminator: Optional[str] = None,
repr: bool = True,
ui_type: Optional[UIType] = None,
ui_hidden: bool = False,
**kwargs: Any,
) -> Any:
"""
Creates an output field for an invocation output.

This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
that adds a few extra parameters to support graph execution and the node editor UI.

:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
In some situations, the field's type is not enough to infer the correct UI type. \
For example, model selection fields should render a dropdown UI component to select a model. \
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.

:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
"""
return Field(
*args,
default=default,
default_factory=default_factory,
alias=alias,
title=title,
description=description,
exclude=exclude,
include=include,
const=const,
gt=gt,
ge=ge,
lt=lt,
le=le,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
min_items=min_items,
max_items=max_items,
unique_items=unique_items,
min_length=min_length,
max_length=max_length,
allow_mutation=allow_mutation,
regex=regex,
discriminator=discriminator,
repr=repr,
ui_type=ui_type,
ui_hidden=ui_hidden,
**kwargs,
)


class UIConfigBase(BaseModel):
"""
Provides additional node configuration to the UI.
This is used internally by the @tags and @title decorator logic. You probably want to use those
decorators, though you may add this class to a node definition to specify the title and tags.
"""

tags: Optional[list[str]] = Field(default_factory=None, description="The tags to display in the UI")
title: Optional[str] = Field(default=None, description="The display name of the node")


class InvocationContext:
services: InvocationServices
graph_execution_state_id: str
@@ -39,6 +390,20 @@ class BaseInvocationOutput(BaseModel):
return tuple(subclasses)


class RequiredConnectionException(Exception):
"""Raised when a field which requires a connection did not receive a value."""

def __init__(self, node_id: str, field_name: str):
super().__init__(f"Node {node_id} missing connections for field {field_name}")


class MissingInputException(Exception):
"""Raised when a field which requires some input did not receive a value."""

def __init__(self, node_id: str, field_name: str):
super().__init__(f"Node {node_id} missing value or connection for field {field_name}")

class BaseInvocation(ABC, BaseModel):
"""A node to process inputs and produce outputs.
May use dependency injection in __init__ to receive providers.
@@ -76,70 +441,81 @@ class BaseInvocation(ABC, BaseModel):
def get_output_type(cls):
return signature(cls.invoke).return_annotation

class Config:
@staticmethod
def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
uiconfig = getattr(model_class, "UIConfig", None)
if uiconfig and hasattr(uiconfig, "title"):
schema["title"] = uiconfig.title
if uiconfig and hasattr(uiconfig, "tags"):
schema["tags"] = uiconfig.tags

@abstractmethod
def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
"""Invoke with provided context and return outputs."""
pass

# fmt: off
id: str = Field(description="The id of this node. Must be unique among all nodes.")
is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
# fmt: on
def __init__(self, **data):
# nodes may have required fields, that can accept input from connections
# on instantiation of the model, we need to exclude these from validation
restore = dict()
try:
field_names = list(self.__fields__.keys())
for field_name in field_names:
# if the field is required and may get its value from a connection, exclude it from validation
field = self.__fields__[field_name]
_input = field.field_info.extra.get("input", None)
if _input in [Input.Connection, Input.Any] and field.required:
if field_name not in data:
restore[field_name] = self.__fields__.pop(field_name)
# instantiate the node, which will validate the data
super().__init__(**data)
finally:
# restore the removed fields
for field_name, field in restore.items():
self.__fields__[field_name] = field

def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
for field_name, field in self.__fields__.items():
_input = field.field_info.extra.get("input", None)
if field.required and not hasattr(self, field_name):
if _input == Input.Connection:
raise RequiredConnectionException(self.__fields__["type"].default, field_name)
elif _input == Input.Any:
raise MissingInputException(self.__fields__["type"].default, field_name)
return self.invoke(context)

id: str = InputField(description="The id of this node. Must be unique among all nodes.")
is_intermediate: bool = InputField(
default=False, description="Whether or not this node is an intermediate node.", input=Input.Direct
)
UIConfig: ClassVar[Type[UIConfigBase]]

# TODO: figure out a better way to provide these hints
# TODO: when we can upgrade to python 3.11, we can use the `NotRequired` type instead of `total=False`
class UIConfig(TypedDict, total=False):
type_hints: Dict[
str,
Literal[
"integer",
"float",
"boolean",
"string",
"enum",
"image",
"latents",
"model",
"control",
"image_collection",
"vae_model",
"lora_model",
],
]
tags: List[str]
title: str
T = TypeVar("T", bound=BaseInvocation)


class CustomisedSchemaExtra(TypedDict):
ui: UIConfig
def title(title: str) -> Callable[[Type[T]], Type[T]]:
"""Adds a title to the invocation. Use this to override the default title generation, which is based on the class name."""

def wrapper(cls: Type[T]) -> Type[T]:
uiconf_name = cls.__qualname__ + ".UIConfig"
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
cls.UIConfig.title = title
return cls

return wrapper


class InvocationConfig(BaseConfig):
"""Customizes pydantic's BaseModel.Config class for use by Invocations.
def tags(*tags: str) -> Callable[[Type[T]], Type[T]]:
"""Adds tags to the invocation. Use this to make the invocation easier to find in the UI."""

Provide `schema_extra` a `ui` dict to add hints for generated UIs.
def wrapper(cls: Type[T]) -> Type[T]:
uiconf_name = cls.__qualname__ + ".UIConfig"
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
cls.UIConfig.tags = list(tags)
return cls

`tags`
- A list of strings, used to categorise invocations.

`type_hints`
- A dict of field types which override the types in the invocation definition.
- Each key should be the name of one of the invocation's fields.
- Each value should be one of the valid types:
- `integer`, `float`, `boolean`, `string`, `enum`, `image`, `latents`, `model`

```python
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["stable-diffusion", "image"],
"type_hints": {
"initial_image": "image",
},
},
}
```
"""

schema_extra: CustomisedSchemaExtra
return wrapper
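Taken together, the new decorators and field helpers replace the old `Config`/`schema_extra` approach shown above. A minimal sketch of a node written against the new API (a hypothetical invocation, not part of the diff):

```python
# Hypothetical invocation using @title/@tags and InputField; the node type and fields are invented.
@title("Example Add")
@tags("math", "example")
class ExampleAddInvocation(BaseInvocation):
    """Adds two numbers"""

    type: Literal["example_add"] = "example_add"

    # Inputs
    a: int = InputField(default=0, description=FieldDescriptions.num_1)
    b: int = InputField(default=0, description=FieldDescriptions.num_2)

    def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
        ...

# The decorators attach a per-class UIConfig subclass, which Config.schema_extra()
# then copies into the OpenAPI schema as the node's "title" and "tags".
assert ExampleAddInvocation.UIConfig.title == "Example Add"
assert ExampleAddInvocation.UIConfig.tags == ["math", "example"]
```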
@@ -3,58 +3,25 @@
from typing import Literal

import numpy as np
from pydantic import Field, validator
from pydantic import validator

from invokeai.app.models.image import ImageField
from invokeai.app.invocations.primitives import ImageCollectionOutput, ImageField, IntegerCollectionOutput
from invokeai.app.util.misc import SEED_MAX, get_random_seed

from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext, UIConfig


class IntCollectionOutput(BaseInvocationOutput):
"""A collection of integers"""

type: Literal["int_collection"] = "int_collection"

# Outputs
collection: list[int] = Field(default=[], description="The int collection")


class FloatCollectionOutput(BaseInvocationOutput):
"""A collection of floats"""

type: Literal["float_collection"] = "float_collection"

# Outputs
collection: list[float] = Field(default=[], description="The float collection")


class ImageCollectionOutput(BaseInvocationOutput):
"""A collection of images"""

type: Literal["image_collection"] = "image_collection"

# Outputs
collection: list[ImageField] = Field(default=[], description="The output images")

class Config:
schema_extra = {"required": ["type", "collection"]}
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIType, tags, title


@title("Integer Range")
@tags("collection", "integer", "range")
class RangeInvocation(BaseInvocation):
"""Creates a range of numbers from start to stop with step"""

type: Literal["range"] = "range"

# Inputs
start: int = Field(default=0, description="The start of the range")
stop: int = Field(default=10, description="The stop of the range")
step: int = Field(default=1, description="The step of the range")

class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Range", "tags": ["range", "integer", "collection"]},
}
start: int = InputField(default=0, description="The start of the range")
stop: int = InputField(default=10, description="The stop of the range")
step: int = InputField(default=1, description="The step of the range")

@validator("stop")
def stop_gt_start(cls, v, values):
@@ -62,76 +29,44 @@ class RangeInvocation(BaseInvocation):
raise ValueError("stop must be greater than start")
return v

def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
return IntegerCollectionOutput(collection=list(range(self.start, self.stop, self.step)))


@title("Integer Range of Size")
@tags("range", "integer", "size", "collection")
class RangeOfSizeInvocation(BaseInvocation):
"""Creates a range from start to start + size with step"""

type: Literal["range_of_size"] = "range_of_size"

# Inputs
start: int = Field(default=0, description="The start of the range")
size: int = Field(default=1, description="The number of values")
step: int = Field(default=1, description="The step of the range")
start: int = InputField(default=0, description="The start of the range")
size: int = InputField(default=1, description="The number of values")
step: int = InputField(default=1, description="The step of the range")

class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Sized Range", "tags": ["range", "integer", "size", "collection"]},
}

def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
return IntegerCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))


@title("Random Range")
@tags("range", "integer", "random", "collection")
class RandomRangeInvocation(BaseInvocation):
"""Creates a collection of random numbers"""

type: Literal["random_range"] = "random_range"

# Inputs
low: int = Field(default=0, description="The inclusive low value")
high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value")
size: int = Field(default=1, description="The number of values to generate")
seed: int = Field(
low: int = InputField(default=0, description="The inclusive low value")
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
size: int = InputField(default=1, description="The number of values to generate")
seed: int = InputField(
ge=0,
le=SEED_MAX,
description="The seed for the RNG (omit for random)",
default_factory=get_random_seed,
)

class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Random Range", "tags": ["range", "integer", "random", "collection"]},
}

def invoke(self, context: InvocationContext) -> IntCollectionOutput:
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
rng = np.random.default_rng(self.seed)
return IntCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size)))


class ImageCollectionInvocation(BaseInvocation):
"""Load a collection of images and provide it as output."""

# fmt: off
type: Literal["image_collection"] = "image_collection"

# Inputs
images: list[ImageField] = Field(
default=[], description="The image collection to load"
)
# fmt: on

def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
return ImageCollectionOutput(collection=self.images)

class Config(InvocationConfig):
schema_extra = {
"ui": {
"type_hints": {
"title": "Image Collection",
"images": "image_collection",
}
},
}
return IntegerCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size)))
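The renames above do not change what the range nodes compute; the same integer collections are produced, only wrapped in `IntegerCollectionOutput` from the primitives module. For illustration (made-up node ids and values):

```python
# Illustrative only: mirrors what invoke() places in IntegerCollectionOutput.collection.
range_node = RangeInvocation(id="range-1", start=0, stop=10, step=2)
sized_node = RangeOfSizeInvocation(id="range-2", start=5, size=4, step=1)

print(list(range(range_node.start, range_node.stop, range_node.step)))                     # [0, 2, 4, 6, 8]
print(list(range(sized_node.start, sized_node.start + sized_node.size, sized_node.step)))  # [5, 6, 7, 8]
```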
@@ -1,32 +1,35 @@
from typing import Literal, Optional, Union, List, Annotated
from pydantic import BaseModel, Field
import re

from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from .model import ClipField

from ...backend.util.devices import torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.model_management import BaseModelType, ModelType, SubModelType, ModelPatcher
from dataclasses import dataclass
from typing import List, Literal, Union

import torch
from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from ...backend.util.devices import torch_dtype
from ...backend.model_management import ModelType
from ...backend.model_management.models import ModelNotFoundException
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput

from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import (
BasicConditioningInfo,
SDXLConditioningInfo,
)

from ...backend.model_management import ModelPatcher, ModelType
from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import InvokeAIDiffuserComponent, BasicConditioningInfo, SDXLConditioningInfo
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from ...backend.model_management.models import ModelNotFoundException
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.util.devices import torch_dtype
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
OutputField,
UIComponent,
tags,
title,
)
from .model import ClipField
from dataclasses import dataclass


class ConditioningField(BaseModel):
conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")

class Config:
schema_extra = {"required": ["conditioning_name"]}


@dataclass
|
||||
# PerpNeg = "perp_neg"
|
||||
|
||||
|
||||
class CompelOutput(BaseInvocationOutput):
|
||||
"""Compel parser output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["compel_output"] = "compel_output"
|
||||
|
||||
conditioning: ConditioningField = Field(default=None, description="Conditioning")
|
||||
# fmt: on
|
||||
|
||||
|
||||
@title("Compel Prompt")
|
||||
@tags("prompt", "compel")
|
||||
class CompelInvocation(BaseInvocation):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
|
||||
type: Literal["compel"] = "compel"
|
||||
|
||||
prompt: str = Field(default="", description="Prompt")
|
||||
clip: ClipField = Field(None, description="Clip to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
|
||||
}
|
||||
prompt: str = InputField(
|
||||
default="",
|
||||
description=FieldDescriptions.compel_prompt,
|
||||
ui_component=UIComponent.Textarea,
|
||||
)
|
||||
clip: ClipField = InputField(
|
||||
title="CLIP",
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**self.clip.tokenizer.dict(),
|
||||
context=context,
|
||||
@ -149,7 +146,7 @@ class CompelInvocation(BaseInvocation):
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
context.services.latents.save(conditioning_name, conditioning_data)
|
||||
|
||||
return CompelOutput(
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
@ -270,30 +267,26 @@ class SDXLPromptInvocationBase:
|
||||
return c, c_pooled, ec
|
||||
|
||||
|
||||
@title("SDXL Compel Prompt")
|
||||
@tags("sdxl", "compel", "prompt")
|
||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
|
||||
type: Literal["sdxl_compel_prompt"] = "sdxl_compel_prompt"
|
||||
|
||||
prompt: str = Field(default="", description="Prompt")
|
||||
style: str = Field(default="", description="Style prompt")
|
||||
original_width: int = Field(1024, description="")
|
||||
original_height: int = Field(1024, description="")
|
||||
crop_top: int = Field(0, description="")
|
||||
crop_left: int = Field(0, description="")
|
||||
target_width: int = Field(1024, description="")
|
||||
target_height: int = Field(1024, description="")
|
||||
clip: ClipField = Field(None, description="Clip to use")
|
||||
clip2: ClipField = Field(None, description="Clip2 to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "SDXL Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
|
||||
}
|
||||
prompt: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
|
||||
style: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
|
||||
original_width: int = InputField(default=1024, description="")
|
||||
original_height: int = InputField(default=1024, description="")
|
||||
crop_top: int = InputField(default=0, description="")
|
||||
crop_left: int = InputField(default=0, description="")
|
||||
target_width: int = InputField(default=1024, description="")
|
||||
target_height: int = InputField(default=1024, description="")
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
c1, c1_pooled, ec1 = self.run_clip_compel(
|
||||
context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
|
||||
)
|
||||
@ -326,38 +319,32 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
context.services.latents.save(conditioning_name, conditioning_data)
|
||||
|
||||
return CompelOutput(
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@title("SDXL Refiner Compel Prompt")
|
||||
@tags("sdxl", "compel", "prompt")
|
||||
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
|
||||
type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt"
|
||||
|
||||
style: str = Field(default="", description="Style prompt") # TODO: ?
|
||||
original_width: int = Field(1024, description="")
|
||||
original_height: int = Field(1024, description="")
|
||||
crop_top: int = Field(0, description="")
|
||||
crop_left: int = Field(0, description="")
|
||||
aesthetic_score: float = Field(6.0, description="")
|
||||
clip2: ClipField = Field(None, description="Clip to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "SDXL Refiner Prompt (Compel)",
|
||||
"tags": ["prompt", "compel"],
|
||||
"type_hints": {"model": "model"},
|
||||
},
|
||||
}
|
||||
style: str = InputField(
|
||||
default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea
|
||||
) # TODO: ?
|
||||
original_width: int = InputField(default=1024, description="")
|
||||
original_height: int = InputField(default=1024, description="")
|
||||
crop_top: int = InputField(default=0, description="")
|
||||
crop_left: int = InputField(default=0, description="")
|
||||
aesthetic_score: float = InputField(default=6.0, description=FieldDescriptions.sdxl_aesthetic)
|
||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
# TODO: if there will appear lora for refiner - write proper prefix
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>", zero_on_empty=False)
|
||||
|
||||
@ -380,7 +367,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
context.services.latents.save(conditioning_name, conditioning_data)
|
||||
|
||||
return CompelOutput(
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
@ -391,21 +378,18 @@ class ClipSkipInvocationOutput(BaseInvocationOutput):
|
||||
"""Clip skip node output"""
|
||||
|
||||
type: Literal["clip_skip_output"] = "clip_skip_output"
|
||||
clip: ClipField = Field(None, description="Clip with skipped layers")
|
||||
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@title("CLIP Skip")
|
||||
@tags("clipskip", "clip", "skip")
|
||||
class ClipSkipInvocation(BaseInvocation):
|
||||
"""Skip layers in clip text_encoder model."""
|
||||
|
||||
type: Literal["clip_skip"] = "clip_skip"
|
||||
|
||||
clip: ClipField = Field(None, description="Clip to use")
|
||||
skipped_layers: int = Field(0, description="Number of layers to skip in text_encoder")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "CLIP Skip", "tags": ["clip", "skip"]},
|
||||
}
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
|
||||
skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
|
||||
self.clip.skipped_layers += self.skipped_layers
|
||||
|
@@ -26,79 +26,31 @@ from controlnet_aux.util import HWC3, ade_palette
from PIL import Image
from pydantic import BaseModel, Field, validator

from invokeai.app.invocations.primitives import ImageField, ImageOutput


from ...backend.model_management import BaseModelType, ModelType
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from ..models.image import ImageOutput, PILInvocationConfig
from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
InputField,
Input,
InvocationContext,
OutputField,
UIType,
tags,
title,
)

CONTROLNET_DEFAULT_MODELS = [
###########################################
# lllyasviel sd v1.5, ControlNet v1.0 models
##############################################
"lllyasviel/sd-controlnet-canny",
"lllyasviel/sd-controlnet-depth",
"lllyasviel/sd-controlnet-hed",
"lllyasviel/sd-controlnet-seg",
"lllyasviel/sd-controlnet-openpose",
"lllyasviel/sd-controlnet-scribble",
"lllyasviel/sd-controlnet-normal",
"lllyasviel/sd-controlnet-mlsd",
#############################################
# lllyasviel sd v1.5, ControlNet v1.1 models
#############################################
"lllyasviel/control_v11p_sd15_canny",
"lllyasviel/control_v11p_sd15_openpose",
"lllyasviel/control_v11p_sd15_seg",
# "lllyasviel/control_v11p_sd15_depth", # broken
"lllyasviel/control_v11f1p_sd15_depth",
"lllyasviel/control_v11p_sd15_normalbae",
"lllyasviel/control_v11p_sd15_scribble",
"lllyasviel/control_v11p_sd15_mlsd",
"lllyasviel/control_v11p_sd15_softedge",
"lllyasviel/control_v11p_sd15s2_lineart_anime",
"lllyasviel/control_v11p_sd15_lineart",
"lllyasviel/control_v11p_sd15_inpaint",
# "lllyasviel/control_v11u_sd15_tile",
# problem (temporary?) with huggingface "lllyasviel/control_v11u_sd15_tile",
# so for now replace "lllyasviel/control_v11f1e_sd15_tile",
"lllyasviel/control_v11e_sd15_shuffle",
"lllyasviel/control_v11e_sd15_ip2p",
"lllyasviel/control_v11f1e_sd15_tile",
#################################################
# thibaud sd v2.1 models (ControlNet v1.0? or v1.1?)
##################################################
"thibaud/controlnet-sd21-openpose-diffusers",
"thibaud/controlnet-sd21-canny-diffusers",
"thibaud/controlnet-sd21-depth-diffusers",
"thibaud/controlnet-sd21-scribble-diffusers",
"thibaud/controlnet-sd21-hed-diffusers",
"thibaud/controlnet-sd21-zoedepth-diffusers",
"thibaud/controlnet-sd21-color-diffusers",
"thibaud/controlnet-sd21-openposev2-diffusers",
"thibaud/controlnet-sd21-lineart-diffusers",
"thibaud/controlnet-sd21-normalbae-diffusers",
"thibaud/controlnet-sd21-ade20k-diffusers",
##############################################
# ControlNetMediaPipeface, ControlNet v1.1
##############################################
# ["CrucibleAI/ControlNetMediaPipeFace", "diffusion_sd15"], # SD 1.5
# diffusion_sd15 needs to be passed to from_pretrained() as subfolder arg
# hacked t2l to split to model & subfolder if format is "model,subfolder"
"CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15", # SD 1.5
"CrucibleAI/ControlNetMediaPipeFace", # SD 2.1?
]

CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]
CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]
CONTROLNET_RESIZE_VALUES = Literal[
tuple(
[
"just_resize",
"crop_resize",
"fill_resize",
"just_resize_simple",
]
)
"just_resize",
"crop_resize",
"fill_resize",
"just_resize_simple",
]

@@ -110,9 +62,8 @@ class ControlNetModelField(BaseModel):


class ControlField(BaseModel):
image: ImageField = Field(default=None, description="The control image")
control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
image: ImageField = Field(description="The control image")
control_model: ControlNetModelField = Field(description="The ControlNet model to use")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
@@ -135,60 +86,39 @@ class ControlField(BaseModel):
raise ValueError("Control weights must be within -1 to 2 range")
return v

class Config:
schema_extra = {
"required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"],
"ui": {
"type_hints": {
"control_weight": "float",
"control_model": "controlnet_model",
# "control_weight": "number",
}
},
}


class ControlOutput(BaseInvocationOutput):
"""node output for ControlNet info"""

# fmt: off
type: Literal["control_output"] = "control_output"
control: ControlField = Field(default=None, description="The control info")
# fmt: on

# Outputs
control: ControlField = OutputField(description=FieldDescriptions.control)


@title("ControlNet")
@tags("controlnet")
class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""

# fmt: off
type: Literal["controlnet"] = "controlnet"
# Inputs
image: ImageField = Field(default=None, description="The control image")
control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny",
description="control model used")
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
begin_step_percent: float = Field(default=0, ge=-1, le=2,
description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)")
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used")
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode used")
# fmt: on

class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "ControlNet",
"tags": ["controlnet", "latents"],
"type_hints": {
"model": "model",
"control": "control",
# "cfg_scale": "float",
"cfg_scale": "number",
"control_weight": "float",
},
},
}
# Inputs
image: ImageField = InputField(description="The control image")
control_model: ControlNetModelField = InputField(
default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
)
control_weight: Union[float, List[float]] = InputField(
default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
)
begin_step_percent: float = InputField(
default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"
)
end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")

def invoke(self, context: InvocationContext) -> ControlOutput:
return ControlOutput(
@@ -204,19 +134,13 @@ class ControlNetInvocation(BaseInvocation):
)

class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
||||
class ImageProcessorInvocation(BaseInvocation):
|
||||
"""Base class for invocations that preprocess images for ControlNet"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["image_processor"] = "image_processor"
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to process")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Image Processor", "tags": ["image", "processor"]},
|
||||
}
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
|
||||
def run_processor(self, image):
|
||||
# superclass just passes through image without processing
|
||||
@ -255,20 +179,20 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
|
||||
class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("Canny Processor")
|
||||
@tags("controlnet", "canny")
|
||||
class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Canny edge detection for ControlNet"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["canny_image_processor"] = "canny_image_processor"
|
||||
# Input
|
||||
low_threshold: int = Field(default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)")
|
||||
high_threshold: int = Field(default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Canny Processor", "tags": ["controlnet", "canny", "image", "processor"]},
|
||||
}
|
||||
# Input
|
||||
low_threshold: int = InputField(
|
||||
default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
|
||||
)
|
||||
high_threshold: int = InputField(
|
||||
default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)"
|
||||
)
|
||||
|
||||
def run_processor(self, image):
|
||||
canny_processor = CannyDetector()
|
||||
@ -276,23 +200,19 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
|
||||
return processed_image
|
||||
|
||||
|
||||
class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("HED (softedge) Processor")
|
||||
@tags("controlnet", "hed", "softedge")
|
||||
class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies HED edge detection to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["hed_image_processor"] = "hed_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe: bool = Field(default=False, description="whether to use safe mode")
|
||||
scribble: bool = Field(default=False, description="Whether to use scribble mode")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Softedge(HED) Processor", "tags": ["controlnet", "softedge", "hed", "image", "processor"]},
|
||||
}
|
||||
# Inputs
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
|
||||
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||
|
||||
def run_processor(self, image):
|
||||
hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
|
||||
@ -307,21 +227,17 @@ class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig)
|
||||
return processed_image
|
||||
|
||||
|
||||
class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("Lineart Processor")
|
||||
@tags("controlnet", "lineart")
|
||||
class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art processing to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["lineart_image_processor"] = "lineart_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
coarse: bool = Field(default=False, description="Whether to use coarse mode")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Lineart Processor", "tags": ["controlnet", "lineart", "image", "processor"]},
|
||||
}
|
||||
# Inputs
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
coarse: bool = InputField(default=False, description="Whether to use coarse mode")
|
||||
|
||||
def run_processor(self, image):
|
||||
lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
|
||||
@ -331,23 +247,16 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCon
|
||||
return processed_image
|
||||
|
||||
|
||||
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
@title("Lineart Anime Processor")
@tags("controlnet", "lineart", "anime")
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
    """Applies line art anime processing to image"""

    # fmt: off
    type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
    # Inputs
    detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
    image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Lineart Anime Processor",
                "tags": ["controlnet", "lineart", "anime", "image", "processor"],
            },
        }
    # Inputs
    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

    def run_processor(self, image):
        processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
@ -359,21 +268,17 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocati
        return processed_image
class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
@title("Openpose Processor")
@tags("controlnet", "openpose", "pose")
class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
    """Applies Openpose processing to image"""

    # fmt: off
    type: Literal["openpose_image_processor"] = "openpose_image_processor"
    # Inputs
    hand_and_face: bool = Field(default=False, description="Whether to use hands and face mode")
    detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
    image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Openpose Processor", "tags": ["controlnet", "openpose", "image", "processor"]},
        }
    # Inputs
    hand_and_face: bool = InputField(default=False, description="Whether to use hands and face mode")
    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

    def run_processor(self, image):
        openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
@ -386,22 +291,18 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
        return processed_image
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
@title("Midas (Depth) Processor")
@tags("controlnet", "midas", "depth")
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
    """Applies Midas depth processing to image"""

    # fmt: off
    type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
    # Inputs
    a_mult: float = Field(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
    bg_th: float = Field(default=0.1, ge=0, description="Midas parameter `bg_th`")
    # depth_and_normal not supported in controlnet_aux v0.0.3
    # depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Midas (Depth) Processor", "tags": ["controlnet", "midas", "depth", "image", "processor"]},
        }
    # Inputs
    a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
    bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`")
    # depth_and_normal not supported in controlnet_aux v0.0.3
    # depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")

    def run_processor(self, image):
        midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
@ -415,20 +316,16 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocation
        return processed_image
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
@title("Normal BAE Processor")
@tags("controlnet", "normal", "bae")
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
    """Applies NormalBae processing to image"""

    # fmt: off
    type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
    # Inputs
    detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
    image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Normal BAE Processor", "tags": ["controlnet", "normal", "bae", "image", "processor"]},
        }
    # Inputs
    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

    def run_processor(self, image):
        normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
@ -438,22 +335,18 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationC
        return processed_image
class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
@title("MLSD Processor")
@tags("controlnet", "mlsd")
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
    """Applies MLSD processing to image"""

    # fmt: off
    type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
    # Inputs
    detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
    image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
    thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_v`")
    thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_d`")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "MLSD Processor", "tags": ["controlnet", "mlsd", "image", "processor"]},
        }
    # Inputs
    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
    thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
    thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")

    def run_processor(self, image):
        mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
@ -467,22 +360,18 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
        return processed_image
class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
@title("PIDI Processor")
@tags("controlnet", "pidi")
class PidiImageProcessorInvocation(ImageProcessorInvocation):
    """Applies PIDI processing to image"""

    # fmt: off
    type: Literal["pidi_image_processor"] = "pidi_image_processor"
    # Inputs
    detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
    image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
    safe: bool = Field(default=False, description="Whether to use safe mode")
    scribble: bool = Field(default=False, description="Whether to use scribble mode")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "PIDI Processor", "tags": ["controlnet", "pidi", "image", "processor"]},
        }
    # Inputs
    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
    safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
    scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)

    def run_processor(self, image):
        pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
@ -496,26 +385,19 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
        return processed_image
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
@title("Content Shuffle Processor")
@tags("controlnet", "contentshuffle")
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
    """Applies content shuffle processing to image"""

    # fmt: off
    type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
    # Inputs
    detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
    image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
    h: Optional[int] = Field(default=512, ge=0, description="Content shuffle `h` parameter")
    w: Optional[int] = Field(default=512, ge=0, description="Content shuffle `w` parameter")
    f: Optional[int] = Field(default=256, ge=0, description="Content shuffle `f` parameter")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Content Shuffle Processor",
                "tags": ["controlnet", "contentshuffle", "image", "processor"],
            },
        }
    # Inputs
    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
    h: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
    w: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
    f: Optional[int] = InputField(default=256, ge=0, description="Content shuffle `f` parameter")

    def run_processor(self, image):
        content_shuffle_processor = ContentShuffleDetector()
@ -531,17 +413,12 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvoca
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
@title("Zoe (Depth) Processor")
@tags("controlnet", "zoe", "depth")
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
    """Applies Zoe depth processing to image"""

    # fmt: off
    type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Zoe (Depth) Processor", "tags": ["controlnet", "zoe", "depth", "image", "processor"]},
        }

    def run_processor(self, image):
        zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
@ -549,20 +426,16 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
        return processed_image
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
@title("Mediapipe Face Processor")
@tags("controlnet", "mediapipe", "face")
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
    """Applies mediapipe face processing to image"""

    # fmt: off
    type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
    # Inputs
    max_faces: int = Field(default=1, ge=1, description="Maximum number of faces to detect")
    min_confidence: float = Field(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Mediapipe Processor", "tags": ["controlnet", "mediapipe", "image", "processor"]},
        }
    # Inputs
    max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
    min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")

    def run_processor(self, image):
        # MediaPipeFaceDetector throws an error if image has alpha channel
@ -574,23 +447,19 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
        return processed_image
class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
@title("Leres (Depth) Processor")
@tags("controlnet", "leres", "depth")
class LeresImageProcessorInvocation(ImageProcessorInvocation):
    """Applies leres processing to image"""

    # fmt: off
    type: Literal["leres_image_processor"] = "leres_image_processor"
    # Inputs
    thr_a: float = Field(default=0, description="Leres parameter `thr_a`")
    thr_b: float = Field(default=0, description="Leres parameter `thr_b`")
    boost: bool = Field(default=False, description="Whether to use boost mode")
    detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
    image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Leres (Depth) Processor", "tags": ["controlnet", "leres", "depth", "image", "processor"]},
        }
    # Inputs
    thr_a: float = InputField(default=0, description="Leres parameter `thr_a`")
    thr_b: float = InputField(default=0, description="Leres parameter `thr_b`")
    boost: bool = InputField(default=False, description="Whether to use boost mode")
    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

    def run_processor(self, image):
        leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
@ -605,21 +474,16 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
        return processed_image
class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    # fmt: off
    type: Literal["tile_image_processor"] = "tile_image_processor"
    # Inputs
    #res: int = Field(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
    down_sampling_rate: float = Field(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
    # fmt: on
@title("Tile Resample Processor")
@tags("controlnet", "tile")
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
    """Tile resampler processor"""

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Tile Resample Processor",
                "tags": ["controlnet", "tile", "resample", "image", "processor"],
            },
        }
    type: Literal["tile_image_processor"] = "tile_image_processor"

    # Inputs
    # res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
    down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")

    # tile_resample copied from sd-webui-controlnet/scripts/processor.py
    def tile_resample(
@ -648,20 +512,12 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
        return processed_image
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
@title("Segment Anything Processor")
@tags("controlnet", "segmentanything")
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
    """Applies segment anything processing to image"""

    # fmt: off
    type: Literal["segment_anything_processor"] = "segment_anything_processor"
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Segment Anything Processor",
                "tags": ["controlnet", "segment", "anything", "sam", "image", "processor"],
            },
        }

    def run_processor(self, image):
        # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
@ -5,40 +5,22 @@ from typing import Literal
import cv2 as cv
import numpy
from PIL import Image, ImageOps
from pydantic import BaseModel, Field
from invokeai.app.invocations.primitives import ImageField, ImageOutput

from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
from invokeai.app.models.image import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title


class CvInvocationConfig(BaseModel):
    """Helper class to provide all OpenCV invocations with additional config"""

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["cv", "image"],
            },
        }


class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
@title("OpenCV Inpaint")
@tags("opencv", "inpaint")
class CvInpaintInvocation(BaseInvocation):
    """Simple inpaint using opencv."""

    # fmt: off
    type: Literal["cv_inpaint"] = "cv_inpaint"

    # Inputs
    image: ImageField = Field(default=None, description="The image to inpaint")
    mask: ImageField = Field(default=None, description="The mask to use when inpainting")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "OpenCV Inpaint", "tags": ["opencv", "inpaint"]},
        }
    image: ImageField = InputField(description="The image to inpaint")
    mask: ImageField = InputField(description="The mask to use when inpainting")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get_pil_image(self.image.image_name)
@ -1,60 +1,31 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from pathlib import Path
from typing import Literal, Optional, Union
from typing import Literal, Optional

import cv2
import numpy
from PIL import Image, ImageChops, ImageFilter, ImageOps
from pydantic import Field

from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker

from ..models.image import ImageCategory, ImageField, ImageOutput, MaskOutput, PILInvocationConfig, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext


class LoadImageInvocation(BaseInvocation):
    """Load an image and provide it as output."""

    # fmt: off
    type: Literal["load_image"] = "load_image"

    # Inputs
    image: Optional[ImageField] = Field(
        default=None, description="The image to load"
    )
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Load Image", "tags": ["image", "load"]},
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get_pil_image(self.image.image_name)

        return ImageOutput(
            image=ImageField(image_name=self.image.image_name),
            width=image.width,
            height=image.height,
        )
from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, tags, title


@title("Show Image")
@tags("image")
class ShowImageInvocation(BaseInvocation):
    """Displays a provided image, and passes it forward in the pipeline."""

    # Metadata
    type: Literal["show_image"] = "show_image"

    # Inputs
    image: Optional[ImageField] = Field(default=None, description="The image to show")

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Show Image", "tags": ["image", "show"]},
        }
    image: ImageField = InputField(description="The image to show")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get_pil_image(self.image.image_name)
@ -70,24 +41,20 @@ class ShowImageInvocation(BaseInvocation):
        )
class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Crop Image")
|
||||
@tags("image", "crop")
|
||||
class ImageCropInvocation(BaseInvocation):
|
||||
"""Crops an image to a specified box. The box can be outside of the image."""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_crop"] = "img_crop"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to crop")
|
||||
x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
|
||||
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
|
||||
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
|
||||
height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Crop Image", "tags": ["image", "crop"]},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to crop")
|
||||
x: int = InputField(default=0, description="The left x coordinate of the crop rectangle")
|
||||
y: int = InputField(default=0, description="The top y coordinate of the crop rectangle")
|
||||
width: int = InputField(default=512, gt=0, description="The width of the crop rectangle")
|
||||
height: int = InputField(default=512, gt=0, description="The height of the crop rectangle")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@ -111,24 +78,23 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
|
||||
class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Paste Image")
|
||||
@tags("image", "paste")
|
||||
class ImagePasteInvocation(BaseInvocation):
|
||||
"""Pastes an image into another image."""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_paste"] = "img_paste"
|
||||
|
||||
# Inputs
|
||||
base_image: Optional[ImageField] = Field(default=None, description="The base image")
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to paste")
|
||||
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
|
||||
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
|
||||
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Paste Image", "tags": ["image", "paste"]},
|
||||
}
|
||||
base_image: ImageField = InputField(description="The base image")
|
||||
image: ImageField = InputField(description="The image to paste")
|
||||
mask: Optional[ImageField] = InputField(
|
||||
default=None,
|
||||
description="The mask to use when pasting",
|
||||
)
|
||||
x: int = InputField(default=0, description="The left x coordinate at which to paste the image")
|
||||
y: int = InputField(default=0, description="The top y coordinate at which to paste the image")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
base_image = context.services.images.get_pil_image(self.base_image.image_name)
|
||||
@ -164,23 +130,19 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
|
||||
class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Mask from Alpha")
|
||||
@tags("image", "mask")
|
||||
class MaskFromAlphaInvocation(BaseInvocation):
|
||||
"""Extracts the alpha channel of an image as a mask."""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["tomask"] = "tomask"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to create the mask from")
|
||||
invert: bool = Field(default=False, description="Whether or not to invert the mask")
|
||||
# fmt: on
|
||||
image: ImageField = InputField(description="The image to create the mask from")
|
||||
invert: bool = InputField(default=False, description="Whether or not to invert the mask")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Mask From Alpha", "tags": ["image", "mask", "alpha"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
image_mask = image.split()[-1]
|
||||
@ -196,28 +158,24 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return MaskOutput(
|
||||
mask=ImageField(image_name=image_dto.image_name),
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Multiply Images")
|
||||
@tags("image", "multiply")
|
||||
class ImageMultiplyInvocation(BaseInvocation):
|
||||
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_mul"] = "img_mul"
|
||||
|
||||
# Inputs
|
||||
image1: Optional[ImageField] = Field(default=None, description="The first image to multiply")
|
||||
image2: Optional[ImageField] = Field(default=None, description="The second image to multiply")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Multiply Images", "tags": ["image", "multiply"]},
|
||||
}
|
||||
image1: ImageField = InputField(description="The first image to multiply")
|
||||
image2: ImageField = InputField(description="The second image to multiply")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image1 = context.services.images.get_pil_image(self.image1.image_name)
|
||||
@ -244,21 +202,17 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
||||
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
||||
|
||||
|
||||
class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Extract Image Channel")
|
||||
@tags("image", "channel")
|
||||
class ImageChannelInvocation(BaseInvocation):
|
||||
"""Gets a channel from an image."""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_chan"] = "img_chan"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to get the channel from")
|
||||
channel: IMAGE_CHANNELS = Field(default="A", description="The channel to get")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Image Channel", "tags": ["image", "channel"]},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to get the channel from")
|
||||
channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@ -284,21 +238,17 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
||||
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
|
||||
|
||||
|
||||
class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Convert Image Mode")
|
||||
@tags("image", "convert")
|
||||
class ImageConvertInvocation(BaseInvocation):
|
||||
"""Converts an image to a different mode."""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_conv"] = "img_conv"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to convert")
|
||||
mode: IMAGE_MODES = Field(default="L", description="The mode to convert to")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Convert Image", "tags": ["image", "convert"]},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to convert")
|
||||
mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@ -321,22 +271,19 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
|
||||
class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Blur Image")
|
||||
@tags("image", "blur")
|
||||
class ImageBlurInvocation(BaseInvocation):
|
||||
"""Blurs an image"""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_blur"] = "img_blur"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to blur")
|
||||
radius: float = Field(default=8.0, ge=0, description="The blur radius")
|
||||
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Blur Image", "tags": ["image", "blur"]},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to blur")
|
||||
radius: float = InputField(default=8.0, ge=0, description="The blur radius")
|
||||
# Metadata
|
||||
blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@ -382,23 +329,19 @@ PIL_RESAMPLING_MAP = {
|
||||
}
|
||||
|
||||
|
||||
class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Resize Image")
|
||||
@tags("image", "resize")
|
||||
class ImageResizeInvocation(BaseInvocation):
|
||||
"""Resizes an image to specific dimensions"""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_resize"] = "img_resize"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to resize")
|
||||
width: Union[int, None] = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||
height: Union[int, None] = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Resize Image", "tags": ["image", "resize"]},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to resize")
|
||||
width: int = InputField(default=512, ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||
height: int = InputField(default=512, ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@ -426,22 +369,22 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
|
||||
class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Scale Image")
|
||||
@tags("image", "scale")
|
||||
class ImageScaleInvocation(BaseInvocation):
|
||||
"""Scales an image by a factor"""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_scale"] = "img_scale"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to scale")
|
||||
scale_factor: Optional[float] = Field(default=2.0, gt=0, description="The factor by which to scale the image")
|
||||
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Scale Image", "tags": ["image", "scale"]},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to scale")
|
||||
scale_factor: float = InputField(
|
||||
default=2.0,
|
||||
gt=0,
|
||||
description="The factor by which to scale the image",
|
||||
)
|
||||
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@ -471,22 +414,18 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
|
||||
class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Lerp Image")
|
||||
@tags("image", "lerp")
|
||||
class ImageLerpInvocation(BaseInvocation):
|
||||
"""Linear interpolation of all pixels of an image"""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_lerp"] = "img_lerp"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to lerp")
|
||||
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
|
||||
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Image Linear Interpolation", "tags": ["image", "linear", "interpolation", "lerp"]},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to lerp")
|
||||
min: int = InputField(default=0, ge=0, le=255, description="The minimum output value")
|
||||
max: int = InputField(default=255, ge=0, le=255, description="The maximum output value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@ -512,25 +451,18 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
|
||||
class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Inverse Lerp Image")
|
||||
@tags("image", "ilerp")
|
||||
class ImageInverseLerpInvocation(BaseInvocation):
|
||||
"""Inverse linear interpolation of all pixels of an image"""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_ilerp"] = "img_ilerp"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to lerp")
|
||||
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
|
||||
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Image Inverse Linear Interpolation",
|
||||
"tags": ["image", "linear", "interpolation", "inverse"],
|
||||
},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to lerp")
|
||||
min: int = InputField(default=0, ge=0, le=255, description="The minimum input value")
|
||||
max: int = InputField(default=255, ge=0, le=255, description="The maximum input value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@ -556,21 +488,19 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
|
||||
class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Blur NSFW Image")
|
||||
@tags("image", "nsfw")
|
||||
class ImageNSFWBlurInvocation(BaseInvocation):
|
||||
"""Add blur to NSFW-flagged images"""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_nsfw"] = "img_nsfw"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to check")
|
||||
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Blur NSFW Images", "tags": ["image", "nsfw", "checker"]},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to check")
|
||||
metadata: Optional[CoreMetadata] = InputField(
|
||||
default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@ -607,22 +537,20 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
return caution.resize((caution.width // 2, caution.height // 2))
|
||||
|
||||
|
||||
class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Add Invisible Watermark")
|
||||
@tags("image", "watermark")
|
||||
class ImageWatermarkInvocation(BaseInvocation):
|
||||
"""Add an invisible watermark to an image"""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_watermark"] = "img_watermark"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to check")
|
||||
text: str = Field(default='InvokeAI', description="Watermark text")
|
||||
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Add Invisible Watermark", "tags": ["image", "watermark", "invisible"]},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to check")
|
||||
text: str = InputField(default="InvokeAI", description="Watermark text")
|
||||
metadata: Optional[CoreMetadata] = InputField(
|
||||
default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@ -644,21 +572,23 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
|
||||
class MaskEdgeInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Mask Edge")
|
||||
@tags("image", "mask", "inpaint")
|
||||
class MaskEdgeInvocation(BaseInvocation):
|
||||
"""Applies an edge mask to an image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["mask_edge"] = "mask_edge"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to apply the mask to")
|
||||
edge_size: int = Field(description="The size of the edge")
|
||||
edge_blur: int = Field(description="The amount of blur on the edge")
|
||||
low_threshold: int = Field(description="First threshold for the hysteresis procedure in Canny edge detection")
|
||||
high_threshold: int = Field(description="Second threshold for the hysteresis procedure in Canny edge detection")
|
||||
# fmt: on
|
||||
image: ImageField = InputField(description="The image to apply the mask to")
|
||||
edge_size: int = InputField(description="The size of the edge")
|
||||
edge_blur: int = InputField(description="The amount of blur on the edge")
|
||||
low_threshold: int = InputField(description="First threshold for the hysteresis procedure in Canny edge detection")
|
||||
high_threshold: int = InputField(
|
||||
description="Second threshold for the hysteresis procedure in Canny edge detection"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
mask = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
npimg = numpy.asarray(mask, dtype=numpy.uint8)
|
||||
@ -683,28 +613,23 @@ class MaskEdgeInvocation(BaseInvocation, PILInvocationConfig):
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return MaskOutput(
|
||||
mask=ImageField(image_name=image_dto.image_name),
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class MaskCombineInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Combine Mask")
|
||||
@tags("image", "mask", "multiply")
|
||||
class MaskCombineInvocation(BaseInvocation):
|
||||
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["mask_combine"] = "mask_combine"
|
||||
|
||||
# Inputs
|
||||
mask1: ImageField = Field(default=None, description="The first mask to combine")
|
||||
mask2: ImageField = Field(default=None, description="The second image to combine")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Mask Combine", "tags": ["mask", "combine"]},
|
||||
}
|
||||
mask1: ImageField = InputField(description="The first mask to combine")
|
||||
mask2: ImageField = InputField(description="The second image to combine")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
mask1 = context.services.images.get_pil_image(self.mask1.image_name).convert("L")
|
||||
@ -728,7 +653,9 @@ class MaskCombineInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
|
||||
class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Color Correct")
|
||||
@tags("image", "color")
|
||||
class ColorCorrectInvocation(BaseInvocation):
|
||||
"""
|
||||
Shifts the colors of a target image to match the reference image, optionally
|
||||
using a mask to only color-correct certain regions of the target image.
|
||||
@ -736,10 +663,11 @@ class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
type: Literal["color_correct"] = "color_correct"
|
||||
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to color-correct")
|
||||
reference: Optional[ImageField] = Field(default=None, description="Reference image for color-correction")
|
||||
mask: Optional[ImageField] = Field(default=None, description="Mask to use when applying color-correction")
|
||||
mask_blur_radius: float = Field(default=8, description="Mask blur radius")
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The image to color-correct")
|
||||
reference: ImageField = InputField(description="Reference image for color-correction")
|
||||
mask: Optional[ImageField] = InputField(default=None, description="Mask to use when applying color-correction")
|
||||
mask_blur_radius: float = InputField(default=8, description="Mask blur radius")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
pil_init_mask = None
|
||||
@ -833,16 +761,16 @@ class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
|
||||
@title("Image Hue Adjustment")
|
||||
@tags("image", "hue", "hsl")
|
||||
class ImageHueAdjustmentInvocation(BaseInvocation):
|
||||
"""Adjusts the Hue of an image."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["img_hue_adjust"] = "img_hue_adjust"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to adjust")
|
||||
hue: int = Field(default=0, description="The degrees by which to rotate the hue, 0-360")
|
||||
# fmt: on
|
||||
image: ImageField = InputField(description="The image to adjust")
|
||||
hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@ -877,16 +805,18 @@ class ImageHueAdjustmentInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@title("Image Luminosity Adjustment")
|
||||
@tags("image", "luminosity", "hsl")
|
||||
class ImageLuminosityAdjustmentInvocation(BaseInvocation):
|
||||
"""Adjusts the Luminosity (Value) of an image."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["img_luminosity_adjust"] = "img_luminosity_adjust"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to adjust")
|
||||
luminosity: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)")
|
||||
# fmt: on
|
||||
image: ImageField = InputField(description="The image to adjust")
|
||||
luminosity: float = InputField(
|
||||
default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@ -925,16 +855,16 @@ class ImageLuminosityAdjustmentInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@title("Image Saturation Adjustment")
|
||||
@tags("image", "saturation", "hsl")
|
||||
class ImageSaturationAdjustmentInvocation(BaseInvocation):
|
||||
"""Adjusts the Saturation of an image."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["img_saturation_adjust"] = "img_saturation_adjust"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to adjust")
|
||||
saturation: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation")
|
||||
# fmt: on
|
||||
image: ImageField = InputField(description="The image to adjust")
|
||||
saturation: float = InputField(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
@ -5,18 +5,13 @@ from typing import Literal, Optional, get_args
|
||||
import numpy as np
|
||||
import math
|
||||
from PIL import Image, ImageOps
|
||||
from pydantic import Field
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput, ColorField
|
||||
|
||||
from invokeai.app.invocations.image import ImageOutput
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
|
||||
from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
InvocationConfig,
|
||||
InvocationContext,
|
||||
)
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, title, tags
|
||||
|
||||
|
||||
def infill_methods() -> list[str]:
|
||||
@ -114,21 +109,20 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
|
||||
return si
|
||||
|
||||
|
||||
@title("Solid Color Infill")
|
||||
@tags("image", "inpaint")
|
||||
class InfillColorInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image with a solid color"""
|
||||
|
||||
type: Literal["infill_rgba"] = "infill_rgba"
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
||||
color: ColorField = Field(
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
color: ColorField = InputField(
|
||||
default=ColorField(r=127, g=127, b=127, a=255),
|
||||
description="The color to use to infill",
|
||||
)
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Color Infill", "tags": ["image", "inpaint", "color", "infill"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
@ -153,25 +147,23 @@ class InfillColorInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@title("Tile Infill")
|
||||
@tags("image", "inpaint")
|
||||
class InfillTileInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image with tiles of the image"""
|
||||
|
||||
type: Literal["infill_tile"] = "infill_tile"
|
||||
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
||||
tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
|
||||
seed: int = Field(
|
||||
# Input
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
tile_size: int = InputField(default=32, ge=1, description="The tile size (px)")
|
||||
seed: int = InputField(
|
||||
ge=0,
|
||||
le=SEED_MAX,
|
||||
description="The seed to use for tile generation (omit for random)",
|
||||
default_factory=get_random_seed,
|
||||
)
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Tile Infill", "tags": ["image", "inpaint", "tile", "infill"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
@ -194,17 +186,15 @@ class InfillTileInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@title("PatchMatch Infill")
|
||||
@tags("image", "inpaint")
|
||||
class InfillPatchMatchInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
||||
|
||||
type: Literal["infill_patchmatch"] = "infill_patchmatch"
|
||||
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Patch Match Infill", "tags": ["image", "inpaint", "patchmatch", "infill"]},
|
||||
}
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
@ -13,16 +13,25 @@ from diffusers.models.attention_processor import (
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from diffusers.schedulers import DPMSolverSDEScheduler, SchedulerMixin as Scheduler
|
||||
from diffusers.schedulers import DPMSolverSDEScheduler
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from invokeai.app.invocations.primitives import (
|
||||
ImageField,
|
||||
ImageOutput,
|
||||
LatentsField,
|
||||
LatentsOutput,
|
||||
build_latents_output,
|
||||
)
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
||||
|
||||
from ...backend.model_management import BaseModelType, ModelPatcher
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
ConditioningData,
|
||||
@ -32,48 +41,27 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
)
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from ...backend.util.devices import choose_precision, choose_torch_device, torch_dtype
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||
from ...backend.util.devices import choose_precision, choose_torch_device
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
tags,
|
||||
title,
|
||||
)
|
||||
from .compel import ConditioningField
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .image import ImageOutput
|
||||
from .model import ModelInfo, UNetField, VaeField
|
||||
|
||||
DEFAULT_PRECISION = choose_precision(choose_torch_device())
|
||||
|
||||
|
||||
class LatentsField(BaseModel):
|
||||
"""A latents field used for passing latents between invocations"""
|
||||
|
||||
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
|
||||
seed: Optional[int] = Field(description="Seed used to generate this latents")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["latents_name"]}
|
||||
|
||||
|
||||
class LatentsOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output latents"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["latents_output"] = "latents_output"
|
||||
|
||||
# Inputs
|
||||
latents: LatentsField = Field(default=None, description="The output latents")
|
||||
width: int = Field(description="The width of the latents in pixels")
|
||||
height: int = Field(description="The height of the latents in pixels")
|
||||
# fmt: on
|
||||
|
||||
|
||||
def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int]):
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=latents_name, seed=seed),
|
||||
width=latents.size()[3] * 8,
|
||||
height=latents.size()[2] * 8,
|
||||
)
|
||||
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
|
||||
|
||||
|
||||
@ -111,30 +99,36 @@ def get_scheduler(
return scheduler


@title("Denoise Latents")
@tags("latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l")
class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images"""

type: Literal["denoise_latents"] = "denoise_latents"

# Inputs
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: Union[float, List[float]] = Field(
default=7.5,
ge=1,
description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt",
positive_conditioning: ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use")
unet: UNetField = Field(default=None, description="UNet submodel")
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
mask: Optional[ImageField] = Field(
None,
description="Mask",
negative_conditioning: ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection
)
noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection)
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
cfg_scale: Union[float, List[float]] = InputField(
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, ui_type=UIType.Float
)
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
scheduler: SAMPLER_NAME_VALUES = InputField(default="euler", description=FieldDescriptions.scheduler)
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection)
control: Union[ControlField, list[ControlField]] = InputField(
default=None, description=FieldDescriptions.control, input=Input.Connection
)
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
mask: Optional[ImageField] = InputField(
default=None,
description=FieldDescriptions.mask,
)

@validator("cfg_scale")
@ -149,20 +143,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
raise ValueError("cfg_scale must be greater than 1")
return v

# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Denoise Latents",
"tags": ["denoise", "latents"],
"type_hints": {
"model": "model",
"control": "control",
"cfg_scale": "number",
},
},
}

# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self,
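The cfg_scale validator referenced in the hunk above has to accept either a single float or a per-step list of floats. A minimal sketch of that kind of check, written as a standalone pydantic model for illustration; the exact body in the commit may differ:

# Sketch only: mirrors the "cfg_scale must be greater than 1" check above.
from typing import List, Union
from pydantic import BaseModel, validator

class CfgHolder(BaseModel):  # hypothetical stand-in for the invocation class
    cfg_scale: Union[float, List[float]] = 7.5

    @validator("cfg_scale")
    def ge_one(cls, v):
        """Validate that every cfg_scale value is at least 1."""
        if isinstance(v, list):
            for i in v:
                if i < 1:
                    raise ValueError("cfg_scale must be greater than 1")
        elif v < 1:
            raise ValueError("cfg_scale must be greater than 1")
        return v
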
@ -474,29 +454,29 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
|
||||
|
||||
|
||||
# Latent to image
|
||||
@title("Latents to Image")
|
||||
@tags("latents", "image", "vae")
|
||||
class LatentsToImageInvocation(BaseInvocation):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
type: Literal["l2i"] = "l2i"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
tiled: bool = Field(default=False, description="Decode latents by overlaping tiles (less memory consumption)")
|
||||
fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
|
||||
metadata: Optional[CoreMetadata] = Field(
|
||||
default=None, description="Optional core metadata to be written to the image"
|
||||
latents: LatentsField = InputField(
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
vae: VaeField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
||||
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
|
||||
metadata: CoreMetadata = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.core_metadata,
|
||||
ui_hidden=True,
|
||||
)
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Latents To Image",
|
||||
"tags": ["latents", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
@ -574,24 +554,30 @@ class LatentsToImageInvocation(BaseInvocation):
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]


@title("Resize Latents")
@tags("latents", "resize")
class ResizeLatentsInvocation(BaseInvocation):
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""

type: Literal["lresize"] = "lresize"

# Inputs
latents: Optional[LatentsField] = Field(description="The latents to resize")
width: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The width to resize to (px)")
height: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The height to resize to (px)")
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
antialias: bool = Field(
default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)

class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Resize Latents", "tags": ["latents", "resize"]},
}
width: int = InputField(
ge=64,
multiple_of=8,
description=FieldDescriptions.width,
)
height: int = InputField(
ge=64,
multiple_of=8,
description=FieldDescriptions.width,
)
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)

def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
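The resize itself is not fully visible in this hunk; the field names suggest the latents go through torch's interpolate with the chosen mode and antialias flag. A minimal standalone sketch of that kind of operation, assuming torch.nn.functional.interpolate is the underlying call (the commit's actual implementation may differ):

# Sketch only: resizing a latent tensor the way the fields above describe.
# Pixel width/height are floor-divided by 8 to get the latent-space target size.
import torch
import torch.nn.functional as F

latents = torch.randn(1, 4, 64, 64)
target_w, target_h = 768, 512
resized = F.interpolate(
    latents,
    size=(target_h // 8, target_w // 8),
    mode="bilinear",
    antialias=True,  # honoured only for bilinear/bicubic modes
)
print(resized.shape)  # torch.Size([1, 4, 64, 96])
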
@ -616,23 +602,21 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||
|
||||
|
||||
@title("Scale Latents")
|
||||
@tags("latents", "resize")
|
||||
class ScaleLatentsInvocation(BaseInvocation):
|
||||
"""Scales latents by a given factor."""
|
||||
|
||||
type: Literal["lscale"] = "lscale"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to scale")
|
||||
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
|
||||
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
||||
antialias: bool = Field(
|
||||
default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
|
||||
latents: LatentsField = InputField(
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Scale Latents", "tags": ["latents", "scale"]},
|
||||
}
|
||||
scale_factor: float = InputField(gt=0, description=FieldDescriptions.scale_factor)
|
||||
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
|
||||
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
@ -658,22 +642,23 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||
|
||||
|
||||
@title("Image to Latents")
|
||||
@tags("latents", "image", "vae")
|
||||
class ImageToLatentsInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
|
||||
type: Literal["i2l"] = "i2l"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(description="The image to encode")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)")
|
||||
fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Image To Latents", "tags": ["latents", "image"]},
|
||||
}
|
||||
image: ImageField = InputField(
|
||||
description="The image to encode",
|
||||
)
|
||||
vae: VaeField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
||||
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
|
@ -2,134 +2,83 @@
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
import numpy as np
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
InvocationConfig,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import IntegerOutput
|
||||
|
||||
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, tags, title
|
||||
|
||||
|
||||
class MathInvocationConfig(BaseModel):
|
||||
"""Helper class to provide all math invocations with additional config"""
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["math"],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class IntOutput(BaseInvocationOutput):
|
||||
"""An integer output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["int_output"] = "int_output"
|
||||
a: int = Field(default=None, description="The output integer")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class FloatOutput(BaseInvocationOutput):
|
||||
"""A float output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["float_output"] = "float_output"
|
||||
param: float = Field(default=None, description="The output float")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class AddInvocation(BaseInvocation, MathInvocationConfig):
|
||||
@title("Add Integers")
|
||||
@tags("math")
|
||||
class AddInvocation(BaseInvocation):
|
||||
"""Adds two numbers"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["add"] = "add"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Add", "tags": ["math", "add"]},
|
||||
}
|
||||
# Inputs
|
||||
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a + self.b)
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(a=self.a + self.b)
|
||||
|
||||
|
||||
class SubtractInvocation(BaseInvocation, MathInvocationConfig):
|
||||
@title("Subtract Integers")
|
||||
@tags("math")
|
||||
class SubtractInvocation(BaseInvocation):
|
||||
"""Subtracts two numbers"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["sub"] = "sub"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Subtract", "tags": ["math", "subtract"]},
|
||||
}
|
||||
# Inputs
|
||||
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a - self.b)
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(a=self.a - self.b)
|
||||
|
||||
|
||||
class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
|
||||
@title("Multiply Integers")
|
||||
@tags("math")
|
||||
class MultiplyInvocation(BaseInvocation):
|
||||
"""Multiplies two numbers"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["mul"] = "mul"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Multiply", "tags": ["math", "multiply"]},
|
||||
}
|
||||
# Inputs
|
||||
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a * self.b)
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(a=self.a * self.b)
|
||||
|
||||
|
||||
class DivideInvocation(BaseInvocation, MathInvocationConfig):
@title("Divide Integers")
@tags("math")
class DivideInvocation(BaseInvocation):
"""Divides two numbers"""

# fmt: off
type: Literal["div"] = "div"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
# fmt: on

class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Divide", "tags": ["math", "divide"]},
}
# Inputs
a: int = InputField(default=0, description=FieldDescriptions.num_1)
b: int = InputField(default=0, description=FieldDescriptions.num_2)

def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=int(self.a / self.b))
def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntegerOutput(a=int(self.a / self.b))

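One detail worth noting in DivideInvocation above: int(a / b) truncates toward zero, whereas a // b floors, so the two disagree for negative operands. A quick illustration:

# Illustrates the rounding behaviour used by DivideInvocation above.
print(int(7 / 2))    # 3
print(int(-7 / 2))   # -3 (truncates toward zero)
print(-7 // 2)       # -4 (floor division)
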
@title("Random Integer")
|
||||
@tags("math")
|
||||
class RandomIntInvocation(BaseInvocation):
|
||||
"""Outputs a single random integer."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["rand_int"] = "rand_int"
|
||||
low: int = Field(default=0, description="The inclusive low value")
|
||||
high: int = Field(
|
||||
default=np.iinfo(np.int32).max, description="The exclusive high value"
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Random Integer", "tags": ["math", "random", "integer"]},
|
||||
}
|
||||
# Inputs
|
||||
low: int = InputField(default=0, description="The inclusive low value")
|
||||
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=np.random.randint(self.low, self.high))
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(a=np.random.randint(self.low, self.high))
|
||||
|
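As the field descriptions above state, low is inclusive and high is exclusive, which matches numpy's randint semantics:

# Demonstrates the [low, high) convention used by RandomIntInvocation above.
import numpy as np

samples = np.random.randint(0, 3, size=1000)
print(sorted(set(samples.tolist())))  # [0, 1, 2] -- 3 is never drawn; high is exclusive
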
@ -1,18 +1,22 @@
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from ...version import __version__
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationConfig,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
tags,
|
||||
title,
|
||||
)
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
|
||||
from ...version import __version__
|
||||
|
||||
|
||||
class LoRAMetadataField(BaseModelExcludeNull):
|
||||
"""LoRA metadata for an image generated in InvokeAI."""
|
||||
@ -43,37 +47,37 @@ class CoreMetadata(BaseModelExcludeNull):
model: MainModelField = Field(description="The main model used for inference")
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
vae: Union[VAEModelField, None] = Field(
vae: Optional[VAEModelField] = Field(
default=None,
description="The VAE used for decoding, if the main model's default was not used",
)

# Latents-to-Latents
strength: Union[float, None] = Field(
strength: Optional[float] = Field(
default=None,
description="The strength used for latents-to-latents",
)
init_image: Union[str, None] = Field(default=None, description="The name of the initial image")
init_image: Optional[str] = Field(default=None, description="The name of the initial image")

# SDXL
positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter")
negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter")
positive_style_prompt: Optional[str] = Field(default=None, description="The positive style prompt parameter")
negative_style_prompt: Optional[str] = Field(default=None, description="The negative style prompt parameter")

# SDXL Refiner
refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used")
refiner_cfg_scale: Union[float, None] = Field(
refiner_model: Optional[MainModelField] = Field(default=None, description="The SDXL Refiner model used")
refiner_cfg_scale: Optional[float] = Field(
default=None,
description="The classifier-free guidance scale parameter used for the refiner",
)
refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner")
refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner")
refiner_positive_aesthetic_store: Union[float, None] = Field(
refiner_steps: Optional[int] = Field(default=None, description="The number of steps used for the refiner")
refiner_scheduler: Optional[str] = Field(default=None, description="The scheduler used for the refiner")
refiner_positive_aesthetic_store: Optional[float] = Field(
default=None, description="The aesthetic score used for the refiner"
)
refiner_negative_aesthetic_store: Union[float, None] = Field(
refiner_negative_aesthetic_store: Optional[float] = Field(
default=None, description="The aesthetic score used for the refiner"
)
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
refiner_start: Optional[float] = Field(default=None, description="The start value used for refiner denoising")

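The hunk above is purely a spelling normalisation at the type level: Optional[X] is defined as Union[X, None], so the annotations are interchangeable and validation behaviour is unchanged. For example:

# Optional[X] and Union[X, None] are the same typing construct.
from typing import Optional, Union

assert Optional[float] == Union[float, None]
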
class ImageMetadata(BaseModelExcludeNull):
|
||||
@ -91,69 +95,86 @@ class MetadataAccumulatorOutput(BaseInvocationOutput):
|
||||
|
||||
type: Literal["metadata_accumulator_output"] = "metadata_accumulator_output"
|
||||
|
||||
metadata: CoreMetadata = Field(description="The core metadata for the image")
|
||||
metadata: CoreMetadata = OutputField(description="The core metadata for the image")
|
||||
|
||||
|
||||
@title("Metadata Accumulator")
|
||||
@tags("metadata")
|
||||
class MetadataAccumulatorInvocation(BaseInvocation):
|
||||
"""Outputs a Core Metadata Object"""
|
||||
|
||||
type: Literal["metadata_accumulator"] = "metadata_accumulator"
|
||||
|
||||
generation_mode: str = Field(
|
||||
generation_mode: str = InputField(
|
||||
description="The generation mode that output this image",
|
||||
)
|
||||
positive_prompt: str = Field(description="The positive prompt parameter")
|
||||
negative_prompt: str = Field(description="The negative prompt parameter")
|
||||
width: int = Field(description="The width parameter")
|
||||
height: int = Field(description="The height parameter")
|
||||
seed: int = Field(description="The seed used for noise generation")
|
||||
rand_device: str = Field(description="The device used for random number generation")
|
||||
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
|
||||
steps: int = Field(description="The number of steps used for inference")
|
||||
scheduler: str = Field(description="The scheduler used for inference")
|
||||
clip_skip: int = Field(
|
||||
positive_prompt: str = InputField(description="The positive prompt parameter")
|
||||
negative_prompt: str = InputField(description="The negative prompt parameter")
|
||||
width: int = InputField(description="The width parameter")
|
||||
height: int = InputField(description="The height parameter")
|
||||
seed: int = InputField(description="The seed used for noise generation")
|
||||
rand_device: str = InputField(description="The device used for random number generation")
|
||||
cfg_scale: float = InputField(description="The classifier-free guidance scale parameter")
|
||||
steps: int = InputField(description="The number of steps used for inference")
|
||||
scheduler: str = InputField(description="The scheduler used for inference")
|
||||
clip_skip: int = InputField(
|
||||
description="The number of skipped CLIP layers",
|
||||
)
|
||||
model: MainModelField = Field(description="The main model used for inference")
|
||||
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
|
||||
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
||||
strength: Union[float, None] = Field(
|
||||
model: MainModelField = InputField(description="The main model used for inference")
|
||||
controlnets: list[ControlField] = InputField(description="The ControlNets used for inference")
|
||||
loras: list[LoRAMetadataField] = InputField(description="The LoRAs used for inference")
|
||||
strength: Optional[float] = InputField(
|
||||
default=None,
|
||||
description="The strength used for latents-to-latents",
|
||||
)
|
||||
init_image: Union[str, None] = Field(default=None, description="The name of the initial image")
|
||||
vae: Union[VAEModelField, None] = Field(
|
||||
init_image: Optional[str] = InputField(
|
||||
default=None,
|
||||
description="The name of the initial image",
|
||||
)
|
||||
vae: Optional[VAEModelField] = InputField(
|
||||
default=None,
|
||||
description="The VAE used for decoding, if the main model's default was not used",
|
||||
)
|
||||
|
||||
# SDXL
|
||||
positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter")
|
||||
negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter")
|
||||
positive_style_prompt: Optional[str] = InputField(
|
||||
default=None,
|
||||
description="The positive style prompt parameter",
|
||||
)
|
||||
negative_style_prompt: Optional[str] = InputField(
|
||||
default=None,
|
||||
description="The negative style prompt parameter",
|
||||
)
|
||||
|
||||
# SDXL Refiner
|
||||
refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used")
|
||||
refiner_cfg_scale: Union[float, None] = Field(
|
||||
refiner_model: Optional[MainModelField] = InputField(
|
||||
default=None,
|
||||
description="The SDXL Refiner model used",
|
||||
)
|
||||
refiner_cfg_scale: Optional[float] = InputField(
|
||||
default=None,
|
||||
description="The classifier-free guidance scale parameter used for the refiner",
|
||||
)
|
||||
refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner")
|
||||
refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner")
|
||||
refiner_positive_aesthetic_score: Union[float, None] = Field(
|
||||
default=None, description="The aesthetic score used for the refiner"
|
||||
refiner_steps: Optional[int] = InputField(
|
||||
default=None,
|
||||
description="The number of steps used for the refiner",
|
||||
)
|
||||
refiner_negative_aesthetic_score: Union[float, None] = Field(
|
||||
default=None, description="The aesthetic score used for the refiner"
|
||||
refiner_scheduler: Optional[str] = InputField(
|
||||
default=None,
|
||||
description="The scheduler used for the refiner",
|
||||
)
|
||||
refiner_positive_aesthetic_store: Optional[float] = InputField(
|
||||
default=None,
|
||||
description="The aesthetic score used for the refiner",
|
||||
)
|
||||
refiner_negative_aesthetic_store: Optional[float] = InputField(
|
||||
default=None,
|
||||
description="The aesthetic score used for the refiner",
|
||||
)
|
||||
refiner_start: Optional[float] = InputField(
|
||||
default=None,
|
||||
description="The start value used for refiner denoising",
|
||||
)
|
||||
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Metadata Accumulator",
|
||||
"tags": ["image", "metadata", "generation"],
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
|
||||
"""Collects and outputs a CoreMetadata object"""
|
||||
|
@ -4,7 +4,18 @@ from typing import List, Literal, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
FieldDescriptions,
|
||||
InputField,
|
||||
Input,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
tags,
|
||||
title,
|
||||
)
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
@ -39,13 +50,11 @@ class VaeField(BaseModel):
|
||||
class ModelLoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["model_loader_output"] = "model_loader_output"
|
||||
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
# fmt: on
|
||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
class MainModelField(BaseModel):
|
||||
@ -63,24 +72,17 @@ class LoRAModelField(BaseModel):
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
|
||||
@title("Main Model Loader")
|
||||
@tags("model")
|
||||
class MainModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
||||
type: Literal["main_model_loader"] = "main_model_loader"
|
||||
|
||||
model: MainModelField = Field(description="The model to load")
|
||||
# Inputs
|
||||
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
|
||||
# TODO: precision?
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Model Loader",
|
||||
"tags": ["model", "loader"],
|
||||
"type_hints": {"model": "model"},
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
@ -155,22 +157,6 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
clip2=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Tokenizer2,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.TextEncoder2,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
@ -188,30 +174,27 @@ class LoraLoaderOutput(BaseInvocationOutput):
|
||||
# fmt: off
|
||||
type: Literal["lora_loader_output"] = "lora_loader_output"
|
||||
|
||||
unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
|
||||
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
# fmt: on
|
||||
|
||||
|
||||
@title("LoRA Loader")
|
||||
@tags("lora", "model")
|
||||
class LoraLoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
type: Literal["lora_loader"] = "lora_loader"
|
||||
|
||||
lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name")
|
||||
weight: float = Field(default=0.75, description="With what weight to apply lora")
|
||||
|
||||
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
|
||||
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Lora Loader",
|
||||
"tags": ["lora", "loader"],
|
||||
"type_hints": {"lora": "lora_model"},
|
||||
},
|
||||
}
|
||||
# Inputs
|
||||
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
|
||||
)
|
||||
clip: Optional[ClipField] = InputField(
|
||||
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
||||
if self.lora is None:
|
||||
@ -263,37 +246,35 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
|
||||
|
||||
class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
"""SDXL LoRA Loader Output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output"
|
||||
|
||||
unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
|
||||
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
clip2: Optional[ClipField] = Field(default=None, description="Tokenizer2 and text_encoder2 submodels")
|
||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
|
||||
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
|
||||
# fmt: on
|
||||
|
||||
|
||||
@title("SDXL LoRA Loader")
|
||||
@tags("sdxl", "lora", "model")
|
||||
class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader"
|
||||
|
||||
lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name")
|
||||
weight: float = Field(default=0.75, description="With what weight to apply lora")
|
||||
|
||||
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
|
||||
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
|
||||
clip2: Optional[ClipField] = Field(description="Clip2 model for applying lora")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "SDXL Lora Loader",
|
||||
"tags": ["lora", "loader"],
|
||||
"type_hints": {"lora": "lora_model"},
|
||||
},
|
||||
}
|
||||
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||
weight: float = Field(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
unet: Optional[UNetField] = Field(
|
||||
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNET"
|
||||
)
|
||||
clip: Optional[ClipField] = Field(
|
||||
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1"
|
||||
)
|
||||
clip2: Optional[ClipField] = Field(
|
||||
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
|
||||
if self.lora is None:
|
||||
@ -369,29 +350,23 @@ class VAEModelField(BaseModel):
|
||||
class VaeLoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["vae_loader_output"] = "vae_loader_output"
|
||||
|
||||
vae: VaeField = Field(default=None, description="Vae model")
|
||||
# fmt: on
|
||||
# Outputs
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@title("VAE Loader")
|
||||
@tags("vae", "model")
|
||||
class VaeLoaderInvocation(BaseInvocation):
|
||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||
|
||||
type: Literal["vae_loader"] = "vae_loader"
|
||||
|
||||
vae_model: VAEModelField = Field(description="The VAE to load")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "VAE Loader",
|
||||
"tags": ["vae", "loader"],
|
||||
"type_hints": {"vae_model": "vae_model"},
|
||||
},
|
||||
}
|
||||
# Inputs
|
||||
vae_model: VAEModelField = InputField(
|
||||
description=FieldDescriptions.vae_model, input=Input.Direct, ui_type=UIType.VaeModel, title="VAE"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
|
||||
base_model = self.vae_model.base_model
|
||||
|
@ -1,19 +1,24 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
||||
|
||||
import math
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field, validator
|
||||
import torch
|
||||
from invokeai.app.invocations.latent import LatentsField
|
||||
from pydantic import validator
|
||||
|
||||
from invokeai.app.invocations.latent import LatentsField
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationConfig,
|
||||
FieldDescriptions,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
tags,
|
||||
title,
|
||||
)
|
||||
|
||||
"""
|
||||
@ -61,14 +66,12 @@ Nodes
|
||||
class NoiseOutput(BaseInvocationOutput):
|
||||
"""Invocation noise output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["noise_output"] = "noise_output"
|
||||
type: Literal["noise_output"] = "noise_output"
|
||||
|
||||
# Inputs
|
||||
noise: LatentsField = Field(default=None, description="The output noise")
|
||||
width: int = Field(description="The width of the noise in pixels")
|
||||
height: int = Field(description="The height of the noise in pixels")
|
||||
# fmt: on
|
||||
noise: LatentsField = OutputField(default=None, description=FieldDescriptions.noise)
|
||||
width: int = OutputField(description=FieldDescriptions.width)
|
||||
height: int = OutputField(description=FieldDescriptions.height)
|
||||
|
||||
|
||||
def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
|
||||
@ -79,44 +82,37 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
)


@title("Noise")
@tags("latents", "noise")
class NoiseInvocation(BaseInvocation):
"""Generates latent noise."""

type: Literal["noise"] = "noise"

# Inputs
seed: int = Field(
seed: int = InputField(
ge=0,
le=SEED_MAX,
description="The seed to use",
description=FieldDescriptions.seed,
default_factory=get_random_seed,
)
width: int = Field(
width: int = InputField(
default=512,
multiple_of=8,
gt=0,
description="The width of the resulting noise",
description=FieldDescriptions.width,
)
height: int = Field(
height: int = InputField(
default=512,
multiple_of=8,
gt=0,
description="The height of the resulting noise",
description=FieldDescriptions.height,
)
use_cpu: bool = Field(
use_cpu: bool = InputField(
default=True,
description="Use CPU for noise generation (for reproducible results across platforms)",
)

# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Noise",
"tags": ["latents", "noise"],
},
}

@validator("seed", pre=True)
def modulo_seed(cls, v):
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""

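The use_cpu flag exists because seeded CPU generation is stable across platforms, while GPU RNG streams can vary between devices. A minimal sketch of seeded latent-noise generation of the kind described above; the 1/8 spatial factor follows build_latents_output earlier in the diff, and the commit's actual implementation may differ:

# Sketch only: reproducible latent noise generated on the CPU, then moved to the target device.
import torch

seed, width, height = 1234, 512, 512
generator = torch.Generator(device="cpu").manual_seed(seed)
noise = torch.randn(
    (1, 4, height // 8, width // 8),  # latent-space shape
    generator=generator,
    device="cpu",
)
noise = noise.to("cuda" if torch.cuda.is_available() else "cpu")
print(noise.shape)  # torch.Size([1, 4, 64, 64])
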
@ -1,37 +1,43 @@
|
||||
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)
|
||||
|
||||
import inspect
|
||||
import re
|
||||
from contextlib import ExitStack
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
import re
|
||||
import inspect
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from ...backend.model_management import ONNXModelPatcher
|
||||
from ...backend.util import choose_torch_device
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||
from .compel import ConditioningField
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .image import ImageOutput
|
||||
from .model import ModelInfo, UNetField, VaeField
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from invokeai.backend import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend import BaseModelType, ModelType, SubModelType
|
||||
|
||||
from ...backend.model_management import ONNXModelPatcher
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
|
||||
from tqdm import tqdm
|
||||
from .model import ClipField
|
||||
from .latent import LatentsField, LatentsOutput, build_latents_output, get_scheduler, SAMPLER_NAME_VALUES
|
||||
from .compel import CompelOutput
|
||||
|
||||
from ...backend.util import choose_torch_device
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
FieldDescriptions,
|
||||
InputField,
|
||||
Input,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIComponent,
|
||||
UIType,
|
||||
tags,
|
||||
title,
|
||||
)
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler
|
||||
from .model import ClipField, ModelInfo, UNetField, VaeField
|
||||
|
||||
ORT_TO_NP_TYPE = {
|
||||
"tensor(bool)": np.bool_,
|
||||
@ -51,13 +57,15 @@ ORT_TO_NP_TYPE = {
|
||||
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
|
||||
|
||||
|
||||
@title("ONNX Prompt (Raw)")
|
||||
@tags("onnx", "prompt")
|
||||
class ONNXPromptInvocation(BaseInvocation):
|
||||
type: Literal["prompt_onnx"] = "prompt_onnx"
|
||||
|
||||
prompt: str = Field(default="", description="Prompt")
|
||||
clip: ClipField = Field(None, description="Clip to use")
|
||||
prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**self.clip.tokenizer.dict(),
|
||||
)
|
||||
@ -126,7 +134,7 @@ class ONNXPromptInvocation(BaseInvocation):
|
||||
# TODO: hacky but works ;D maybe rename latents somehow?
|
||||
context.services.latents.save(conditioning_name, (prompt_embeds, None))
|
||||
|
||||
return CompelOutput(
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
@ -134,25 +142,48 @@ class ONNXPromptInvocation(BaseInvocation):
|
||||
|
||||
|
||||
# Text to image
|
||||
@title("ONNX Text to Latents")
|
||||
@tags("latents", "inference", "txt2img", "onnx")
|
||||
class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
"""Generates latents from conditionings."""
|
||||
|
||||
type: Literal["t2l_onnx"] = "t2l_onnx"
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||
precision: PRECISION_VALUES = Field(default = "tensor(float16)", description="The precision to use when generating latents")
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
# fmt: on
|
||||
positive_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond,
|
||||
input=Input.Connection,
|
||||
)
|
||||
negative_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.negative_cond,
|
||||
input=Input.Connection,
|
||||
)
|
||||
noise: LatentsField = InputField(
|
||||
description=FieldDescriptions.noise,
|
||||
input=Input.Connection,
|
||||
)
|
||||
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
||||
cfg_scale: Union[float, List[float]] = InputField(
|
||||
default=7.5,
|
||||
ge=1,
|
||||
description=FieldDescriptions.cfg_scale,
|
||||
ui_type=UIType.Float,
|
||||
)
|
||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct
|
||||
)
|
||||
precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision)
|
||||
unet: UNetField = InputField(
|
||||
description=FieldDescriptions.unet,
|
||||
input=Input.Connection,
|
||||
)
|
||||
control: Optional[Union[ControlField, list[ControlField]]] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.control,
|
||||
ui_type=UIType.Control,
|
||||
)
|
||||
# seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
# seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
|
||||
@validator("cfg_scale")
|
||||
def ge_one(cls, v):
|
||||
@ -166,20 +197,6 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
return v
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
"control": "control",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# based on
|
||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
@ -300,26 +317,28 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
|
||||
|
||||
# Latent to image
|
||||
@title("ONNX Latents to Image")
|
||||
@tags("latents", "image", "vae", "onnx")
|
||||
class ONNXLatentsToImageInvocation(BaseInvocation):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
type: Literal["l2i_onnx"] = "l2i_onnx"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
metadata: Optional[CoreMetadata] = Field(
|
||||
default=None, description="Optional core metadata to be written to the image"
|
||||
latents: LatentsField = InputField(
|
||||
description=FieldDescriptions.denoised_latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
# tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents", "image"],
|
||||
},
|
||||
}
|
||||
vae: VaeField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
metadata: Optional[CoreMetadata] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.core_metadata,
|
||||
ui_hidden=True,
|
||||
)
|
||||
# tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
@ -373,89 +392,13 @@ class ONNXModelLoaderOutput(BaseInvocationOutput):
|
||||
# fmt: off
|
||||
type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx"
|
||||
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
vae_decoder: VaeField = Field(default=None, description="Vae submodel")
|
||||
vae_encoder: VaeField = Field(default=None, description="Vae submodel")
|
||||
unet: UNetField = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
vae_decoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Decoder")
|
||||
vae_encoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Encoder")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class ONNXSD1ModelLoaderInvocation(BaseInvocation):
|
||||
"""Loading submodels of selected model."""
|
||||
|
||||
type: Literal["sd1_model_loader_onnx"] = "sd1_model_loader_onnx"
|
||||
|
||||
model_name: str = Field(default="", description="Model to load")
|
||||
# TODO: precision?
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"tags": ["model", "loader"], "type_hints": {"model_name": "model"}}, # TODO: rename to model_name?
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
|
||||
model_name = "stable-diffusion-v1-5"
|
||||
base_model = BaseModelType.StableDiffusion1
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=model_name,
|
||||
base_model=BaseModelType.StableDiffusion1,
|
||||
model_type=ModelType.ONNX,
|
||||
):
|
||||
raise Exception(f"Unkown model name: {model_name}!")
|
||||
|
||||
return ONNXModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.ONNX,
|
||||
submodel=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.ONNX,
|
||||
submodel=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.ONNX,
|
||||
submodel=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.ONNX,
|
||||
submodel=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
vae_decoder=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.ONNX,
|
||||
submodel=SubModelType.VaeDecoder,
|
||||
),
|
||||
),
|
||||
vae_encoder=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.ONNX,
|
||||
submodel=SubModelType.VaeEncoder,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class OnnxModelField(BaseModel):
|
||||
"""Onnx model field"""
|
||||
|
||||
@ -464,22 +407,17 @@ class OnnxModelField(BaseModel):
|
||||
model_type: ModelType = Field(description="Model Type")
|
||||
|
||||
|
||||
@title("ONNX Model Loader")
|
||||
@tags("onnx", "model")
|
||||
class OnnxModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
||||
type: Literal["onnx_model_loader"] = "onnx_model_loader"
|
||||
|
||||
model: OnnxModelField = Field(description="The model to load")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Onnx Model Loader",
|
||||
"tags": ["model", "loader"],
|
||||
"type_hints": {"model": "model"},
|
||||
},
|
||||
}
|
||||
# Inputs
|
||||
model: OnnxModelField = InputField(
|
||||
description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
|
@ -1,73 +1,64 @@
|
||||
import io
|
||||
from typing import Literal, Optional, Any
|
||||
from typing import Literal, Optional
|
||||
|
||||
# from PIL.Image import Image
|
||||
import PIL.Image
|
||||
from matplotlib.ticker import MaxNLocator
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
import PIL.Image
|
||||
from easing_functions import (
|
||||
LinearInOut,
|
||||
QuadEaseInOut,
|
||||
QuadEaseIn,
|
||||
QuadEaseOut,
|
||||
CubicEaseInOut,
|
||||
CubicEaseIn,
|
||||
CubicEaseOut,
|
||||
QuarticEaseInOut,
|
||||
QuarticEaseIn,
|
||||
QuarticEaseOut,
|
||||
QuinticEaseInOut,
|
||||
QuinticEaseIn,
|
||||
QuinticEaseOut,
|
||||
SineEaseInOut,
|
||||
SineEaseIn,
|
||||
SineEaseOut,
|
||||
CircularEaseIn,
|
||||
CircularEaseInOut,
|
||||
CircularEaseOut,
|
||||
ExponentialEaseInOut,
|
||||
ExponentialEaseIn,
|
||||
ExponentialEaseOut,
|
||||
ElasticEaseIn,
|
||||
ElasticEaseInOut,
|
||||
ElasticEaseOut,
|
||||
BackEaseIn,
|
||||
BackEaseInOut,
|
||||
BackEaseOut,
|
||||
BounceEaseIn,
|
||||
BounceEaseInOut,
|
||||
BounceEaseOut,
|
||||
CircularEaseIn,
|
||||
CircularEaseInOut,
|
||||
CircularEaseOut,
|
||||
CubicEaseIn,
|
||||
CubicEaseInOut,
|
||||
CubicEaseOut,
|
||||
ElasticEaseIn,
|
||||
ElasticEaseInOut,
|
||||
ElasticEaseOut,
|
||||
ExponentialEaseIn,
|
||||
ExponentialEaseInOut,
|
||||
ExponentialEaseOut,
|
||||
LinearInOut,
|
||||
QuadEaseIn,
|
||||
QuadEaseInOut,
|
||||
QuadEaseOut,
|
||||
QuarticEaseIn,
|
||||
QuarticEaseInOut,
|
||||
QuarticEaseOut,
|
||||
QuinticEaseIn,
|
||||
QuinticEaseInOut,
|
||||
QuinticEaseOut,
|
||||
SineEaseIn,
|
||||
SineEaseInOut,
|
||||
SineEaseOut,
|
||||
)
|
||||
from matplotlib.figure import Figure
|
||||
from matplotlib.ticker import MaxNLocator
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.primitives import FloatCollectionOutput
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
InvocationConfig,
|
||||
)
|
||||
from ...backend.util.logging import InvokeAILogger
|
||||
from .collections import FloatCollectionOutput
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
|
||||
|
||||
|
||||
@title("Float Range")
|
||||
@tags("math", "range")
|
||||
class FloatLinearRangeInvocation(BaseInvocation):
|
||||
"""Creates a range"""
|
||||
|
||||
type: Literal["float_range"] = "float_range"
|
||||
|
||||
# Inputs
|
||||
start: float = Field(default=5, description="The first value of the range")
|
||||
stop: float = Field(default=10, description="The last value of the range")
|
||||
steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Linear Range (Float)", "tags": ["math", "float", "linear", "range"]},
|
||||
}
|
||||
start: float = InputField(default=5, description="The first value of the range")
|
||||
stop: float = InputField(default=10, description="The last value of the range")
|
||||
steps: int = InputField(default=30, description="number of values to interpolate over (including start and stop)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
param_list = list(np.linspace(self.start, self.stop, self.steps))
|
||||
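np.linspace is endpoint-inclusive, which matches the "including start and stop" wording of the steps field above:

# steps=5 over [5, 10] yields both bounds.
import numpy as np

print([float(x) for x in np.linspace(5, 10, 5)])  # [5.0, 6.25, 7.5, 8.75, 10.0]
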
@ -108,37 +99,32 @@ EASING_FUNCTIONS_MAP = {
|
||||
"BounceInOut": BounceEaseInOut,
|
||||
}
|
||||
|
||||
EASING_FUNCTION_KEYS: Any = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
|
||||
EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
|
||||
|
||||
|
||||
# actually I think for now could just use CollectionOutput (which is list[Any]
|
||||
@title("Step Param Easing")
|
||||
@tags("step", "easing")
|
||||
class StepParamEasingInvocation(BaseInvocation):
|
||||
"""Experimental per-step parameter easing for denoising steps"""
|
||||
|
||||
type: Literal["step_param_easing"] = "step_param_easing"
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
easing: EASING_FUNCTION_KEYS = Field(default="Linear", description="The easing function to use")
|
||||
num_steps: int = Field(default=20, description="number of denoising steps")
|
||||
start_value: float = Field(default=0.0, description="easing starting value")
|
||||
end_value: float = Field(default=1.0, description="easing ending value")
|
||||
start_step_percent: float = Field(default=0.0, description="fraction of steps at which to start easing")
|
||||
end_step_percent: float = Field(default=1.0, description="fraction of steps after which to end easing")
|
||||
easing: EASING_FUNCTION_KEYS = InputField(default="Linear", description="The easing function to use")
|
||||
num_steps: int = InputField(default=20, description="number of denoising steps")
|
||||
start_value: float = InputField(default=0.0, description="easing starting value")
|
||||
end_value: float = InputField(default=1.0, description="easing ending value")
|
||||
start_step_percent: float = InputField(default=0.0, description="fraction of steps at which to start easing")
|
||||
end_step_percent: float = InputField(default=1.0, description="fraction of steps after which to end easing")
|
||||
# if None, then start_value is used prior to easing start
|
||||
pre_start_value: Optional[float] = Field(default=None, description="value before easing start")
|
||||
pre_start_value: Optional[float] = InputField(default=None, description="value before easing start")
|
||||
# if None, then end value is used prior to easing end
|
||||
post_end_value: Optional[float] = Field(default=None, description="value after easing end")
|
||||
mirror: bool = Field(default=False, description="include mirror of easing function")
|
||||
post_end_value: Optional[float] = InputField(default=None, description="value after easing end")
|
||||
mirror: bool = InputField(default=False, description="include mirror of easing function")
|
||||
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
|
||||
# alt_mirror: bool = Field(default=False, description="alternative mirroring by dual easing")
|
||||
show_easing_plot: bool = Field(default=False, description="show easing plot")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Param Easing By Step", "tags": ["param", "step", "easing"]},
|
||||
}
|
||||
# alt_mirror: bool = InputField(default=False, description="alternative mirroring by dual easing")
|
||||
show_easing_plot: bool = InputField(default=False, description="show easing plot")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
log_diagnostics = False
|
||||
|
@ -1,83 +0,0 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.invocations.prompt import PromptOutput
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||
from .math import FloatOutput, IntOutput
|
||||
|
||||
# Pass-through parameter nodes - used by subgraphs
|
||||
|
||||
|
||||
class ParamIntInvocation(BaseInvocation):
|
||||
"""An integer parameter"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["param_int"] = "param_int"
|
||||
a: int = Field(default=0, description="The integer value")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"tags": ["param", "integer"], "title": "Integer Parameter"},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a)
|
||||
|
||||
|
||||
class ParamFloatInvocation(BaseInvocation):
|
||||
"""A float parameter"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["param_float"] = "param_float"
|
||||
param: float = Field(default=0.0, description="The float value")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"tags": ["param", "float"], "title": "Float Parameter"},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||
return FloatOutput(param=self.param)
|
||||
|
||||
|
||||
class StringOutput(BaseInvocationOutput):
|
||||
"""A string output"""
|
||||
|
||||
type: Literal["string_output"] = "string_output"
|
||||
text: str = Field(default=None, description="The output string")
|
||||
|
||||
|
||||
class ParamStringInvocation(BaseInvocation):
|
||||
"""A string parameter"""
|
||||
|
||||
type: Literal["param_string"] = "param_string"
|
||||
text: str = Field(default="", description="The string value")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"tags": ["param", "string"], "title": "String Parameter"},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||
return StringOutput(text=self.text)
|
||||
|
||||
|
||||
class ParamPromptInvocation(BaseInvocation):
|
||||
"""A prompt input parameter"""
|
||||
|
||||
type: Literal["param_prompt"] = "param_prompt"
|
||||
prompt: str = Field(default="", description="The prompt value")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"tags": ["param", "prompt"], "title": "Prompt"},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptOutput:
|
||||
return PromptOutput(prompt=self.prompt)
|
494
invokeai/app/invocations/primitives.py
Normal file
494
invokeai/app/invocations/primitives.py
Normal file
@ -0,0 +1,494 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal, Optional, Tuple, Union
|
||||
from anyio import Condition
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
import torch
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIComponent,
|
||||
UIType,
|
||||
tags,
|
||||
title,
|
||||
)
|
||||
|
||||
"""
|
||||
Primitives: Boolean, Integer, Float, String, Image, Latents, Conditioning, Color
|
||||
- primitive nodes
|
||||
- primitive outputs
|
||||
- primitive collection outputs
|
||||
"""
|
||||
|
||||
# region Boolean
|
||||
|
||||
|
||||
class BooleanOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single boolean"""
|
||||
|
||||
type: Literal["boolean_output"] = "boolean_output"
|
||||
a: bool = OutputField(description="The output boolean")
|
||||
|
||||
|
||||
class BooleanCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of booleans"""
|
||||
|
||||
type: Literal["boolean_collection_output"] = "boolean_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[bool] = OutputField(
|
||||
default_factory=list, description="The output boolean collection", ui_type=UIType.BooleanCollection
|
||||
)
|
||||
|
||||
|
||||
@title("Boolean")
|
||||
@tags("primitives", "boolean")
|
||||
class BooleanInvocation(BaseInvocation):
|
||||
"""A boolean primitive value"""
|
||||
|
||||
type: Literal["boolean"] = "boolean"
|
||||
|
||||
# Inputs
|
||||
a: bool = InputField(default=False, description="The boolean value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> BooleanOutput:
|
||||
return BooleanOutput(a=self.a)
|
||||
|
||||
|
||||
@title("Boolean Collection")
|
||||
@tags("primitives", "boolean", "collection")
|
||||
class BooleanCollectionInvocation(BaseInvocation):
|
||||
"""A collection of boolean primitive values"""
|
||||
|
||||
type: Literal["boolean_collection"] = "boolean_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[bool] = InputField(
|
||||
default=False, description="The collection of boolean values", ui_type=UIType.BooleanCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
|
||||
return BooleanCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Integer
|
||||
|
||||
|
||||
class IntegerOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single integer"""
|
||||
|
||||
type: Literal["integer_output"] = "integer_output"
|
||||
a: int = OutputField(description="The output integer")
|
||||
|
||||
|
||||
class IntegerCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of integers"""
|
||||
|
||||
type: Literal["integer_collection_output"] = "integer_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[int] = OutputField(
|
||||
default_factory=list, description="The int collection", ui_type=UIType.IntegerCollection
|
||||
)
|
||||
|
||||
|
||||
@title("Integer")
|
||||
@tags("primitives", "integer")
|
||||
class IntegerInvocation(BaseInvocation):
|
||||
"""An integer primitive value"""
|
||||
|
||||
type: Literal["integer"] = "integer"
|
||||
|
||||
# Inputs
|
||||
a: int = InputField(default=0, description="The integer value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(a=self.a)
|
||||
|
||||
|
||||
@title("Integer Collection")
|
||||
@tags("primitives", "integer", "collection")
|
||||
class IntegerCollectionInvocation(BaseInvocation):
|
||||
"""A collection of integer primitive values"""
|
||||
|
||||
type: Literal["integer_collection"] = "integer_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[int] = InputField(
|
||||
default=0, description="The collection of integer values", ui_type=UIType.IntegerCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
|
||||
return IntegerCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Float
|
||||
|
||||
|
||||
class FloatOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single float"""
|
||||
|
||||
type: Literal["float_output"] = "float_output"
|
||||
a: float = OutputField(description="The output float")
|
||||
|
||||
|
||||
class FloatCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of floats"""
|
||||
|
||||
type: Literal["float_collection_output"] = "float_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[float] = OutputField(
|
||||
default_factory=list, description="The float collection", ui_type=UIType.FloatCollection
|
||||
)
|
||||
|
||||
|
||||
@title("Float")
|
||||
@tags("primitives", "float")
|
||||
class FloatInvocation(BaseInvocation):
|
||||
"""A float primitive value"""
|
||||
|
||||
type: Literal["float"] = "float"
|
||||
|
||||
# Inputs
|
||||
param: float = InputField(default=0.0, description="The float value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||
return FloatOutput(a=self.param)
|
||||
|
||||
|
||||
@title("Float Collection")
|
||||
@tags("primitives", "float", "collection")
|
||||
class FloatCollectionInvocation(BaseInvocation):
|
||||
"""A collection of float primitive values"""
|
||||
|
||||
type: Literal["float_collection"] = "float_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[float] = InputField(
|
||||
default=0, description="The collection of float values", ui_type=UIType.FloatCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
return FloatCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region String
|
||||
|
||||
|
||||
class StringOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single string"""
|
||||
|
||||
type: Literal["string_output"] = "string_output"
|
||||
text: str = OutputField(description="The output string")
|
||||
|
||||
|
||||
class StringCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of strings"""
|
||||
|
||||
type: Literal["string_collection_output"] = "string_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[str] = OutputField(
|
||||
default_factory=list, description="The output strings", ui_type=UIType.StringCollection
|
||||
)
|
||||
|
||||
|
||||
@title("String")
|
||||
@tags("primitives", "string")
|
||||
class StringInvocation(BaseInvocation):
|
||||
"""A string primitive value"""
|
||||
|
||||
type: Literal["string"] = "string"
|
||||
|
||||
# Inputs
|
||||
text: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||
return StringOutput(text=self.text)
|
||||
|
||||
|
||||
@title("String Collection")
|
||||
@tags("primitives", "string", "collection")
|
||||
class StringCollectionInvocation(BaseInvocation):
|
||||
"""A collection of string primitive values"""
|
||||
|
||||
type: Literal["string_collection"] = "string_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[str] = InputField(
|
||||
default=0, description="The collection of string values", ui_type=UIType.StringCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
||||
return StringCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Image
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image primitive field"""
|
||||
|
||||
image_name: str = Field(description="The name of the image")
|
||||
|
||||
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single image"""
|
||||
|
||||
type: Literal["image_output"] = "image_output"
|
||||
image: ImageField = OutputField(description="The output image")
|
||||
width: int = OutputField(description="The width of the image in pixels")
|
||||
height: int = OutputField(description="The height of the image in pixels")
|
||||
|
||||
|
||||
class ImageCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of images"""
|
||||
|
||||
type: Literal["image_collection_output"] = "image_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[ImageField] = OutputField(
|
||||
default_factory=list, description="The output images", ui_type=UIType.ImageCollection
|
||||
)
|
||||
|
||||
|
||||
@title("Image Primitive")
|
||||
@tags("primitives", "image")
|
||||
class ImageInvocation(BaseInvocation):
|
||||
"""An image primitive value"""
|
||||
|
||||
# Metadata
|
||||
type: Literal["image"] = "image"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The image to load")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=self.image.image_name),
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
)
|
||||
|
||||
|
||||
@title("Image Collection")
|
||||
@tags("primitives", "image", "collection")
|
||||
class ImageCollectionInvocation(BaseInvocation):
|
||||
"""A collection of image primitive values"""
|
||||
|
||||
type: Literal["image_collection"] = "image_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[ImageField] = InputField(
|
||||
default=0, description="The collection of image values", ui_type=UIType.ImageCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
|
||||
return ImageCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Latents
|
||||
|
||||
|
||||
class LatentsField(BaseModel):
|
||||
"""A latents tensor primitive field"""
|
||||
|
||||
latents_name: str = Field(description="The name of the latents")
|
||||
seed: Optional[int] = Field(default=None, description="Seed used to generate this latents")
|
||||
|
||||
|
||||
class LatentsOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single latents tensor"""
|
||||
|
||||
type: Literal["latents_output"] = "latents_output"
|
||||
|
||||
latents: LatentsField = OutputField(
|
||||
description=FieldDescriptions.latents,
|
||||
)
|
||||
width: int = OutputField(description=FieldDescriptions.width)
|
||||
height: int = OutputField(description=FieldDescriptions.height)
|
||||
|
||||
|
||||
class LatentsCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of latents tensors"""
|
||||
|
||||
type: Literal["latents_collection_output"] = "latents_collection_output"
|
||||
|
||||
collection: list[LatentsField] = OutputField(
|
||||
default_factory=list,
|
||||
description=FieldDescriptions.latents,
|
||||
ui_type=UIType.LatentsCollection,
|
||||
)
|
||||
|
||||
|
||||
@title("Latents Primitive")
|
||||
@tags("primitives", "latents")
|
||||
class LatentsInvocation(BaseInvocation):
|
||||
"""A latents tensor primitive value"""
|
||||
|
||||
type: Literal["latents"] = "latents"
|
||||
|
||||
# Inputs
|
||||
latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
return build_latents_output(self.latents.latents_name, latents)
|
||||
|
||||
|
||||
@title("Latents Collection")
|
||||
@tags("primitives", "latents", "collection")
|
||||
class LatentsCollectionInvocation(BaseInvocation):
|
||||
"""A collection of latents tensor primitive values"""
|
||||
|
||||
type: Literal["latents_collection"] = "latents_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[LatentsField] = InputField(
|
||||
default=0, description="The collection of latents tensors", ui_type=UIType.LatentsCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsCollectionOutput:
|
||||
return LatentsCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None):
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=latents_name, seed=seed),
|
||||
width=latents.size()[3] * 8,
|
||||
height=latents.size()[2] * 8,
|
||||
)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Color
|
||||
|
||||
|
||||
class ColorField(BaseModel):
|
||||
"""A color primitive field"""
|
||||
|
||||
r: int = Field(ge=0, le=255, description="The red component")
|
||||
g: int = Field(ge=0, le=255, description="The green component")
|
||||
b: int = Field(ge=0, le=255, description="The blue component")
|
||||
a: int = Field(ge=0, le=255, description="The alpha component")
|
||||
|
||||
def tuple(self) -> Tuple[int, int, int, int]:
|
||||
return (self.r, self.g, self.b, self.a)
|
||||
|
||||
|
||||
class ColorOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single color"""
|
||||
|
||||
type: Literal["color_output"] = "color_output"
|
||||
color: ColorField = OutputField(description="The output color")
|
||||
|
||||
|
||||
class ColorCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of colors"""
|
||||
|
||||
type: Literal["color_collection_output"] = "color_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[ColorField] = OutputField(
|
||||
default_factory=list, description="The output colors", ui_type=UIType.ColorCollection
|
||||
)
|
||||
|
||||
|
||||
@title("Color Primitive")
|
||||
@tags("primitives", "color")
|
||||
class ColorInvocation(BaseInvocation):
|
||||
"""A color primitive value"""
|
||||
|
||||
type: Literal["color"] = "color"
|
||||
|
||||
# Inputs
|
||||
color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ColorOutput:
|
||||
return ColorOutput(color=self.color)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Conditioning
|
||||
|
||||
|
||||
class ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
|
||||
|
||||
class ConditioningOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single conditioning tensor"""
|
||||
|
||||
type: Literal["conditioning_output"] = "conditioning_output"
|
||||
|
||||
conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
|
||||
|
||||
|
||||
class ConditioningCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of conditioning tensors"""
|
||||
|
||||
type: Literal["conditioning_collection_output"] = "conditioning_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[ConditioningField] = OutputField(
|
||||
default_factory=list,
|
||||
description="The output conditioning tensors",
|
||||
ui_type=UIType.ConditioningCollection,
|
||||
)
|
||||
|
||||
|
||||
@title("Conditioning Primitive")
|
||||
@tags("primitives", "conditioning")
|
||||
class ConditioningInvocation(BaseInvocation):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
type: Literal["conditioning"] = "conditioning"
|
||||
|
||||
conditioning: ConditioningField = InputField(description=FieldDescriptions.cond, input=Input.Connection)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
return ConditioningOutput(conditioning=self.conditioning)
|
||||
|
||||
|
||||
@title("Conditioning Collection")
|
||||
@tags("primitives", "conditioning", "collection")
|
||||
class ConditioningCollectionInvocation(BaseInvocation):
|
||||
"""A collection of conditioning tensor primitive values"""
|
||||
|
||||
type: Literal["conditioning_collection"] = "conditioning_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[ConditioningField] = InputField(
|
||||
default=0, description="The collection of conditioning tensors", ui_type=UIType.ConditioningCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput:
|
||||
return ConditioningCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
# endregion
|
@ -1,59 +1,28 @@
|
||||
from os.path import exists
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field, validator
|
||||
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
|
||||
from pydantic import validator
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
|
||||
|
||||
|
||||
class PromptOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a prompt"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["prompt"] = "prompt"
|
||||
|
||||
prompt: str = Field(default=None, description="The output prompt")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
"type",
|
||||
"prompt",
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class PromptCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a collection of prompts"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["prompt_collection_output"] = "prompt_collection_output"
|
||||
|
||||
prompt_collection: list[str] = Field(description="The output prompt collection")
|
||||
count: int = Field(description="The size of the prompt collection")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["type", "prompt_collection", "count"]}
|
||||
from invokeai.app.invocations.primitives import StringCollectionOutput
|
||||
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, UIType, tags, title
|
||||
|
||||
|
||||
@title("Dynamic Prompt")
|
||||
@tags("prompt", "collection")
|
||||
class DynamicPromptInvocation(BaseInvocation):
|
||||
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
|
||||
|
||||
type: Literal["dynamic_prompt"] = "dynamic_prompt"
|
||||
prompt: str = Field(description="The prompt to parse with dynamicprompts")
|
||||
max_prompts: int = Field(default=1, description="The number of prompts to generate")
|
||||
combinatorial: bool = Field(default=False, description="Whether to use the combinatorial generator")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Dynamic Prompt", "tags": ["prompt", "dynamic"]},
|
||||
}
|
||||
# Inputs
|
||||
prompt: str = InputField(description="The prompt to parse with dynamicprompts", ui_component=UIComponent.Textarea)
|
||||
max_prompts: int = InputField(default=1, description="The number of prompts to generate")
|
||||
combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
|
||||
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
||||
if self.combinatorial:
|
||||
generator = CombinatorialPromptGenerator()
|
||||
prompts = generator.generate(self.prompt, max_prompts=self.max_prompts)
|
||||
@ -61,27 +30,26 @@ class DynamicPromptInvocation(BaseInvocation):
|
||||
generator = RandomPromptGenerator()
|
||||
prompts = generator.generate(self.prompt, num_images=self.max_prompts)
|
||||
|
||||
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
|
||||
return StringCollectionOutput(collection=prompts)
|
||||
|
||||
|
||||
@title("Prompts from File")
|
||||
@tags("prompt", "file")
|
||||
class PromptsFromFileInvocation(BaseInvocation):
|
||||
"""Loads prompts from a text file"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal['prompt_from_file'] = 'prompt_from_file'
|
||||
type: Literal["prompt_from_file"] = "prompt_from_file"
|
||||
|
||||
# Inputs
|
||||
file_path: str = Field(description="Path to prompt text file")
|
||||
pre_prompt: Optional[str] = Field(description="String to prepend to each prompt")
|
||||
post_prompt: Optional[str] = Field(description="String to append to each prompt")
|
||||
start_line: int = Field(default=1, ge=1, description="Line in the file to start start from")
|
||||
max_prompts: int = Field(default=1, ge=0, description="Max lines to read from file (0=all)")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Prompts From File", "tags": ["prompt", "file"]},
|
||||
}
|
||||
file_path: str = InputField(description="Path to prompt text file", ui_type=UIType.FilePath)
|
||||
pre_prompt: Optional[str] = InputField(
|
||||
default=None, description="String to prepend to each prompt", ui_component=UIComponent.Textarea
|
||||
)
|
||||
post_prompt: Optional[str] = InputField(
|
||||
default=None, description="String to append to each prompt", ui_component=UIComponent.Textarea
|
||||
)
|
||||
start_line: int = InputField(default=1, ge=1, description="Line in the file to start start from")
|
||||
max_prompts: int = InputField(default=1, ge=0, description="Max lines to read from file (0=all)")
|
||||
|
||||
@validator("file_path")
|
||||
def file_path_exists(cls, v):
|
||||
@ -89,7 +57,14 @@ class PromptsFromFileInvocation(BaseInvocation):
|
||||
raise ValueError(FileNotFoundError)
|
||||
return v
|
||||
|
||||
def promptsFromFile(self, file_path: str, pre_prompt: str, post_prompt: str, start_line: int, max_prompts: int):
|
||||
def promptsFromFile(
|
||||
self,
|
||||
file_path: str,
|
||||
pre_prompt: Union[str, None],
|
||||
post_prompt: Union[str, None],
|
||||
start_line: int,
|
||||
max_prompts: int,
|
||||
):
|
||||
prompts = []
|
||||
start_line -= 1
|
||||
end_line = start_line + max_prompts
|
||||
@ -103,8 +78,8 @@ class PromptsFromFileInvocation(BaseInvocation):
|
||||
break
|
||||
return prompts
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
|
||||
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
||||
prompts = self.promptsFromFile(
|
||||
self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts
|
||||
)
|
||||
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
|
||||
return StringCollectionOutput(collection=prompts)
|
||||
|
@ -1,55 +1,55 @@
|
||||
import torch
|
||||
from typing import Literal
|
||||
from pydantic import Field
|
||||
|
||||
from ...backend.model_management import ModelType, SubModelType
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||
from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
tags,
|
||||
title,
|
||||
)
|
||||
from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField
|
||||
|
||||
|
||||
class SDXLModelLoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL base model loader output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output"
|
||||
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
# fmt: on
|
||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
|
||||
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL refiner model loader output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
# fmt: on
|
||||
# fmt: on
|
||||
|
||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@title("SDXL Main Model Loader")
|
||||
@tags("model", "sdxl")
|
||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl base model, outputting its submodels."""
|
||||
|
||||
type: Literal["sdxl_model_loader"] = "sdxl_model_loader"
|
||||
|
||||
model: MainModelField = Field(description="The model to load")
|
||||
# Inputs
|
||||
model: MainModelField = InputField(
|
||||
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
|
||||
)
|
||||
# TODO: precision?
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "SDXL Model Loader",
|
||||
"tags": ["model", "loader", "sdxl"],
|
||||
"type_hints": {"model": "model"},
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
@ -122,24 +122,21 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@title("SDXL Refiner Model Loader")
|
||||
@tags("model", "sdxl", "refiner")
|
||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||
|
||||
type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
|
||||
|
||||
model: MainModelField = Field(description="The model to load")
|
||||
# Inputs
|
||||
model: MainModelField = InputField(
|
||||
description=FieldDescriptions.sdxl_refiner_model,
|
||||
input=Input.Direct,
|
||||
ui_type=UIType.SDXLRefinerModel,
|
||||
)
|
||||
# TODO: precision?
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "SDXL Refiner Model Loader",
|
||||
"tags": ["model", "loader", "sdxl_refiner"],
|
||||
"type_hints": {"model": "refiner_model"},
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
|
@ -6,13 +6,12 @@ import cv2 as cv
|
||||
import numpy as np
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from PIL import Image
|
||||
from pydantic import Field
|
||||
from realesrgan import RealESRGANer
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
|
||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
|
||||
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
|
||||
from .image import ImageOutput
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, title, tags
|
||||
|
||||
# TODO: Populate this from disk?
|
||||
# TODO: Use model manager to load?
|
||||
@ -24,17 +23,16 @@ ESRGAN_MODELS = Literal[
|
||||
]
|
||||
|
||||
|
||||
@title("Upscale (RealESRGAN)")
|
||||
@tags("esrgan", "upscale")
|
||||
class ESRGANInvocation(BaseInvocation):
|
||||
"""Upscales an image using RealESRGAN."""
|
||||
|
||||
type: Literal["esrgan"] = "esrgan"
|
||||
image: Union[ImageField, None] = Field(default=None, description="The input image")
|
||||
model_name: ESRGAN_MODELS = Field(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Upscale (RealESRGAN)", "tags": ["image", "upscale", "realesrgan"]},
|
||||
}
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The input image")
|
||||
model_name: ESRGAN_MODELS = InputField(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
@ -1,31 +1,8 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, Tuple, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
from ..invocations.baseinvocation import (
|
||||
BaseInvocationOutput,
|
||||
InvocationConfig,
|
||||
)
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image field used for passing image objects between invocations"""
|
||||
|
||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["image_name"]}
|
||||
|
||||
|
||||
class ColorField(BaseModel):
|
||||
r: int = Field(ge=0, le=255, description="The red component")
|
||||
g: int = Field(ge=0, le=255, description="The green component")
|
||||
b: int = Field(ge=0, le=255, description="The blue component")
|
||||
a: int = Field(ge=0, le=255, description="The alpha component")
|
||||
|
||||
def tuple(self) -> Tuple[int, int, int, int]:
|
||||
return (self.r, self.g, self.b, self.a)
|
||||
|
||||
|
||||
class ProgressImage(BaseModel):
|
||||
@ -36,50 +13,6 @@ class ProgressImage(BaseModel):
|
||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||
|
||||
|
||||
class PILInvocationConfig(BaseModel):
|
||||
"""Helper class to provide all PIL invocations with additional config"""
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["PIL", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["image_output"] = "image_output"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["type", "image", "width", "height"]}
|
||||
|
||||
|
||||
class MaskOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a mask"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["mask"] = "mask"
|
||||
mask: ImageField = Field(default=None, description="The output mask")
|
||||
width: int = Field(description="The width of the mask in pixels")
|
||||
height: int = Field(description="The height of the mask in pixels")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
"type",
|
||||
"mask",
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class ResourceOrigin(str, Enum, metaclass=MetaEnum):
|
||||
"""The origin of a resource (eg image).
|
||||
|
||||
|
@ -2,7 +2,7 @@ from ..invocations.latent import LatentsToImageInvocation, DenoiseLatentsInvocat
|
||||
from ..invocations.image import ImageNSFWBlurInvocation
|
||||
from ..invocations.noise import NoiseInvocation
|
||||
from ..invocations.compel import CompelInvocation
|
||||
from ..invocations.params import ParamIntInvocation
|
||||
from ..invocations.primitives import IntegerInvocation
|
||||
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
|
||||
from .item_storage import ItemStorageABC
|
||||
|
||||
@ -17,9 +17,9 @@ def create_text_to_image() -> LibraryGraph:
|
||||
description="Converts text to an image",
|
||||
graph=Graph(
|
||||
nodes={
|
||||
"width": ParamIntInvocation(id="width", a=512),
|
||||
"height": ParamIntInvocation(id="height", a=512),
|
||||
"seed": ParamIntInvocation(id="seed", a=-1),
|
||||
"width": IntegerInvocation(id="width", a=512),
|
||||
"height": IntegerInvocation(id="height", a=512),
|
||||
"seed": IntegerInvocation(id="seed", a=-1),
|
||||
"3": NoiseInvocation(id="3"),
|
||||
"4": CompelInvocation(id="4"),
|
||||
"5": CompelInvocation(id="5"),
|
||||
|
@ -3,16 +3,7 @@
|
||||
import copy
|
||||
import itertools
|
||||
import uuid
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Literal,
|
||||
Optional,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
get_type_hints,
|
||||
)
|
||||
from typing import Annotated, Any, Literal, Optional, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
import networkx as nx
|
||||
from pydantic import BaseModel, root_validator, validator
|
||||
@ -22,7 +13,11 @@ from ..invocations import *
|
||||
from ..invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
)
|
||||
|
||||
# in 3.10 this would be "from types import NoneType"
|
||||
@ -183,15 +178,9 @@ class IterateInvocationOutput(BaseInvocationOutput):
|
||||
|
||||
type: Literal["iterate_output"] = "iterate_output"
|
||||
|
||||
item: Any = Field(description="The item being iterated over")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
"type",
|
||||
"item",
|
||||
]
|
||||
}
|
||||
item: Any = OutputField(
|
||||
description="The item being iterated over", title="Collection Item", ui_type=UIType.CollectionItem
|
||||
)
|
||||
|
||||
|
||||
# TODO: Fill this out and move to invocations
|
||||
@ -200,8 +189,10 @@ class IterateInvocation(BaseInvocation):
|
||||
|
||||
type: Literal["iterate"] = "iterate"
|
||||
|
||||
collection: list[Any] = Field(description="The list of items to iterate over", default_factory=list)
|
||||
index: int = Field(description="The index, will be provided on executed iterators", default=0)
|
||||
collection: list[Any] = InputField(
|
||||
description="The list of items to iterate over", default_factory=list, ui_type=UIType.Collection
|
||||
)
|
||||
index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
|
||||
"""Produces the outputs as values"""
|
||||
@ -211,15 +202,9 @@ class IterateInvocation(BaseInvocation):
|
||||
class CollectInvocationOutput(BaseInvocationOutput):
|
||||
type: Literal["collect_output"] = "collect_output"
|
||||
|
||||
collection: list[Any] = Field(description="The collection of input items")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
"type",
|
||||
"collection",
|
||||
]
|
||||
}
|
||||
collection: list[Any] = OutputField(
|
||||
description="The collection of input items", title="Collection", ui_type=UIType.Collection
|
||||
)
|
||||
|
||||
|
||||
class CollectInvocation(BaseInvocation):
|
||||
@ -227,13 +212,14 @@ class CollectInvocation(BaseInvocation):
|
||||
|
||||
type: Literal["collect"] = "collect"
|
||||
|
||||
item: Any = Field(
|
||||
item: Any = InputField(
|
||||
description="The item to collect (all inputs must be of the same type)",
|
||||
default=None,
|
||||
ui_type=UIType.CollectionItem,
|
||||
title="Collection Item",
|
||||
input=Input.Connection,
|
||||
)
|
||||
collection: list[Any] = Field(
|
||||
description="The collection, will be provided on execution",
|
||||
default_factory=list,
|
||||
collection: list[Any] = InputField(
|
||||
description="The collection, will be provided on execution", default_factory=list, ui_hidden=True
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
|
||||
|
@ -67,6 +67,7 @@ IMAGE_DTO_COLS = ", ".join(
|
||||
"created_at",
|
||||
"updated_at",
|
||||
"deleted_at",
|
||||
"starred",
|
||||
],
|
||||
)
|
||||
)
|
||||
@ -139,6 +140,7 @@ class ImageRecordStorageBase(ABC):
|
||||
node_id: Optional[str],
|
||||
metadata: Optional[dict],
|
||||
is_intermediate: bool = False,
|
||||
starred: bool = False,
|
||||
) -> datetime:
|
||||
"""Saves an image record."""
|
||||
pass
|
||||
@ -198,6 +200,16 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
"""
|
||||
)
|
||||
|
||||
self._cursor.execute("PRAGMA table_info(images)")
|
||||
columns = [column[1] for column in self._cursor.fetchall()]
|
||||
|
||||
if "starred" not in columns:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
ALTER TABLE images ADD COLUMN starred BOOLEAN DEFAULT FALSE;
|
||||
"""
|
||||
)
|
||||
|
||||
# Create the `images` table indices.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
@ -220,6 +232,12 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
"""
|
||||
)
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_images_starred ON images(starred);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
@ -319,6 +337,17 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
(changes.is_intermediate, image_name),
|
||||
)
|
||||
|
||||
# Change the image's `starred`` state
|
||||
if changes.starred is not None:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE images
|
||||
SET starred = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.starred, image_name),
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
@ -395,7 +424,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
query_params.append(board_id)
|
||||
|
||||
query_pagination = """--sql
|
||||
ORDER BY images.created_at DESC LIMIT ? OFFSET ?
|
||||
ORDER BY images.starred DESC, images.created_at DESC LIMIT ? OFFSET ?
|
||||
"""
|
||||
|
||||
# Final images query with pagination
|
||||
@ -498,6 +527,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
node_id: Optional[str],
|
||||
metadata: Optional[dict],
|
||||
is_intermediate: bool = False,
|
||||
starred: bool = False,
|
||||
) -> datetime:
|
||||
try:
|
||||
metadata_json = None if metadata is None else json.dumps(metadata)
|
||||
@ -513,9 +543,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
node_id,
|
||||
session_id,
|
||||
metadata,
|
||||
is_intermediate
|
||||
is_intermediate,
|
||||
starred
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
image_name,
|
||||
@ -527,6 +558,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
session_id,
|
||||
metadata_json,
|
||||
is_intermediate,
|
||||
starred,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
@ -39,6 +39,8 @@ class ImageRecord(BaseModelExcludeNull):
|
||||
description="The node ID that generated this image, if it is a generated image.",
|
||||
)
|
||||
"""The node ID that generated this image, if it is a generated image."""
|
||||
starred: bool = Field(description="Whether this image is starred.")
|
||||
"""Whether this image is starred."""
|
||||
|
||||
|
||||
class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
|
||||
@ -48,6 +50,7 @@ class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
|
||||
- `image_category`: change the category of an image
|
||||
- `session_id`: change the session associated with an image
|
||||
- `is_intermediate`: change the image's `is_intermediate` flag
|
||||
- `starred`: change whether the image is starred
|
||||
"""
|
||||
|
||||
image_category: Optional[ImageCategory] = Field(description="The image's new category.")
|
||||
@ -59,6 +62,8 @@ class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
|
||||
"""The image's new session ID."""
|
||||
is_intermediate: Optional[StrictBool] = Field(default=None, description="The image's new `is_intermediate` flag.")
|
||||
"""The image's new `is_intermediate` flag."""
|
||||
starred: Optional[StrictBool] = Field(default=None, description="The image's new `starred` state")
|
||||
"""The image's new `starred` state."""
|
||||
|
||||
|
||||
class ImageUrlsDTO(BaseModelExcludeNull):
|
||||
@ -113,6 +118,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
updated_at = image_dict.get("updated_at", get_iso_timestamp())
|
||||
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
|
||||
is_intermediate = image_dict.get("is_intermediate", False)
|
||||
starred = image_dict.get("starred", False)
|
||||
|
||||
return ImageRecord(
|
||||
image_name=image_name,
|
||||
@ -126,4 +132,5 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
updated_at=updated_at,
|
||||
deleted_at=deleted_at,
|
||||
is_intermediate=is_intermediate,
|
||||
starred=starred,
|
||||
)
|
||||
|
@ -87,7 +87,10 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
# Invoke
|
||||
try:
|
||||
with statistics.collect_stats(invocation, graph_execution_state.id):
|
||||
outputs = invocation.invoke(
|
||||
# use the internal invoke_internal(), which wraps the node's invoke() method in
|
||||
# this accomodates nodes which require a value, but get it only from a
|
||||
# connection
|
||||
outputs = invocation.invoke_internal(
|
||||
InvocationContext(
|
||||
services=self.__invoker.services,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
|
@ -45,7 +45,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
|
||||
def _parse_item(self, item: str) -> T:
|
||||
item_type = get_args(self.__orig_class__)[0]
|
||||
return parse_raw_as(item_type, item)
|
||||
parsed = parse_raw_as(item_type, item)
|
||||
return parsed
|
||||
|
||||
def set(self, item: T):
|
||||
try:
|
||||
|
@ -61,6 +61,7 @@
|
||||
"@dagrejs/graphlib": "^2.1.13",
|
||||
"@dnd-kit/core": "^6.0.8",
|
||||
"@dnd-kit/modifiers": "^6.0.1",
|
||||
"@dnd-kit/utilities": "^3.2.1",
|
||||
"@emotion/react": "^11.11.1",
|
||||
"@emotion/styled": "^11.11.0",
|
||||
"@floating-ui/react-dom": "^2.0.1",
|
||||
|
34
invokeai/frontend/web/scripts/colors.js
Normal file
34
invokeai/frontend/web/scripts/colors.js
Normal file
@ -0,0 +1,34 @@
|
||||
export const COLORS = {
|
||||
reset: '\x1b[0m',
|
||||
bright: '\x1b[1m',
|
||||
dim: '\x1b[2m',
|
||||
underscore: '\x1b[4m',
|
||||
blink: '\x1b[5m',
|
||||
reverse: '\x1b[7m',
|
||||
hidden: '\x1b[8m',
|
||||
|
||||
fg: {
|
||||
black: '\x1b[30m',
|
||||
red: '\x1b[31m',
|
||||
green: '\x1b[32m',
|
||||
yellow: '\x1b[33m',
|
||||
blue: '\x1b[34m',
|
||||
magenta: '\x1b[35m',
|
||||
cyan: '\x1b[36m',
|
||||
white: '\x1b[37m',
|
||||
gray: '\x1b[90m',
|
||||
crimson: '\x1b[38m',
|
||||
},
|
||||
bg: {
|
||||
black: '\x1b[40m',
|
||||
red: '\x1b[41m',
|
||||
green: '\x1b[42m',
|
||||
yellow: '\x1b[43m',
|
||||
blue: '\x1b[44m',
|
||||
magenta: '\x1b[45m',
|
||||
cyan: '\x1b[46m',
|
||||
white: '\x1b[47m',
|
||||
gray: '\x1b[100m',
|
||||
crimson: '\x1b[48m',
|
||||
},
|
||||
};
|
@ -1,23 +1,83 @@
|
||||
import fs from 'node:fs';
|
||||
import openapiTS from 'openapi-typescript';
|
||||
import { COLORS } from './colors.js';
|
||||
|
||||
const OPENAPI_URL = 'http://127.0.0.1:9090/openapi.json';
|
||||
const OUTPUT_FILE = 'src/services/api/schema.d.ts';
|
||||
|
||||
async function main() {
|
||||
process.stdout.write(
|
||||
`Generating types "${OPENAPI_URL}" --> "${OUTPUT_FILE}"...`
|
||||
`Generating types "${OPENAPI_URL}" --> "${OUTPUT_FILE}"...\n\n`
|
||||
);
|
||||
const types = await openapiTS(OPENAPI_URL, {
|
||||
exportType: true,
|
||||
transform: (schemaObject) => {
|
||||
transform: (schemaObject, metadata) => {
|
||||
if ('format' in schemaObject && schemaObject.format === 'binary') {
|
||||
return schemaObject.nullable ? 'Blob | null' : 'Blob';
|
||||
}
|
||||
|
||||
/**
|
||||
* Because invocations may have required fields that accept connection input, the generated
|
||||
* types may be incorrect.
|
||||
*
|
||||
* For example, the ImageResizeInvocation has a required `image` field, but because it accepts
|
||||
* connection input, it should be optional on instantiation of the field.
|
||||
*
|
||||
* To handle this, the schema exposes an `input` property that can be used to determine if the
|
||||
* field accepts connection input. If it does, we can make the field optional.
|
||||
*/
|
||||
|
||||
// Check if we are generating types for an invocation
|
||||
const isInvocationPath = metadata.path.match(
|
||||
/^#\/components\/schemas\/\w*Invocation$/
|
||||
);
|
||||
|
||||
const hasInvocationProperties =
|
||||
schemaObject.properties &&
|
||||
['id', 'is_intermediate', 'type'].every(
|
||||
(prop) => prop in schemaObject.properties
|
||||
);
|
||||
|
||||
if (isInvocationPath && hasInvocationProperties) {
|
||||
// We only want to make fields optional if they are required
|
||||
if (!Array.isArray(schemaObject?.required)) {
|
||||
schemaObject.required = ['id', 'type'];
|
||||
return;
|
||||
}
|
||||
|
||||
schemaObject.required.forEach((prop) => {
|
||||
const acceptsConnection = ['any', 'connection'].includes(
|
||||
schemaObject.properties?.[prop]?.['input']
|
||||
);
|
||||
|
||||
if (acceptsConnection) {
|
||||
// remove this prop from the required array
|
||||
const invocationName = metadata.path.split('/').pop();
|
||||
console.log(
|
||||
`Making connectable field optional: ${COLORS.fg.green}${invocationName}.${COLORS.fg.cyan}${prop}${COLORS.reset}`
|
||||
);
|
||||
schemaObject.required = schemaObject.required.filter(
|
||||
(r) => r !== prop
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
schemaObject.required = [
|
||||
...new Set(schemaObject.required.concat(['id', 'type'])),
|
||||
];
|
||||
|
||||
return;
|
||||
}
|
||||
// if (
|
||||
// 'input' in schemaObject &&
|
||||
// (schemaObject.input === 'any' || schemaObject.input === 'connection')
|
||||
// ) {
|
||||
// schemaObject.required = false;
|
||||
// }
|
||||
},
|
||||
});
|
||||
fs.writeFileSync(OUTPUT_FILE, types);
|
||||
process.stdout.write(` OK!\r\n`);
|
||||
process.stdout.write(`\nOK!\r\n`);
|
||||
}
|
||||
|
||||
main();
|
||||
|
@ -1,8 +1,12 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
|
||||
import {
|
||||
ctrlKeyPressed,
|
||||
metaKeyPressed,
|
||||
shiftKeyPressed,
|
||||
} from 'features/ui/store/hotkeysSlice';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import {
|
||||
setActiveTab,
|
||||
@ -16,11 +20,11 @@ import React, { memo } from 'react';
|
||||
import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
|
||||
|
||||
const globalHotkeysSelector = createSelector(
|
||||
[(state: RootState) => state.hotkeys, (state: RootState) => state.ui],
|
||||
(hotkeys, ui) => {
|
||||
const { shift } = hotkeys;
|
||||
[stateSelector],
|
||||
({ hotkeys, ui }) => {
|
||||
const { shift, ctrl, meta } = hotkeys;
|
||||
const { shouldPinParametersPanel, shouldPinGallery } = ui;
|
||||
return { shift, shouldPinGallery, shouldPinParametersPanel };
|
||||
return { shift, ctrl, meta, shouldPinGallery, shouldPinParametersPanel };
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
@ -37,9 +41,8 @@ const globalHotkeysSelector = createSelector(
|
||||
*/
|
||||
const GlobalHotkeys: React.FC = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { shift, shouldPinParametersPanel, shouldPinGallery } = useAppSelector(
|
||||
globalHotkeysSelector
|
||||
);
|
||||
const { shift, ctrl, meta, shouldPinParametersPanel, shouldPinGallery } =
|
||||
useAppSelector(globalHotkeysSelector);
|
||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||
|
||||
useHotkeys(
|
||||
@ -50,9 +53,19 @@ const GlobalHotkeys: React.FC = () => {
|
||||
} else {
|
||||
shift && dispatch(shiftKeyPressed(false));
|
||||
}
|
||||
if (isHotkeyPressed('ctrl')) {
|
||||
!ctrl && dispatch(ctrlKeyPressed(true));
|
||||
} else {
|
||||
ctrl && dispatch(ctrlKeyPressed(false));
|
||||
}
|
||||
if (isHotkeyPressed('meta')) {
|
||||
!meta && dispatch(metaKeyPressed(true));
|
||||
} else {
|
||||
meta && dispatch(metaKeyPressed(false));
|
||||
}
|
||||
},
|
||||
{ keyup: true, keydown: true },
|
||||
[shift]
|
||||
[shift, ctrl, meta]
|
||||
);
|
||||
|
||||
useHotkeys('o', () => {
|
||||
|
@ -14,7 +14,7 @@ import { $authToken, $baseUrl, $projectId } from 'services/api/client';
|
||||
import { socketMiddleware } from 'services/events/middleware';
|
||||
import Loading from '../../common/components/Loading/Loading';
|
||||
import '../../i18n';
|
||||
import ImageDndContext from './ImageDnd/ImageDndContext';
|
||||
import AppDndContext from '../../features/dnd/components/AppDndContext';
|
||||
|
||||
const App = lazy(() => import('./App'));
|
||||
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
|
||||
@ -80,9 +80,9 @@ const InvokeAIUI = ({
|
||||
<Provider store={store}>
|
||||
<React.Suspense fallback={<Loading />}>
|
||||
<ThemeLocaleProvider>
|
||||
<ImageDndContext>
|
||||
<AppDndContext>
|
||||
<App config={config} headerComponent={headerComponent} />
|
||||
</ImageDndContext>
|
||||
</AppDndContext>
|
||||
</ThemeLocaleProvider>
|
||||
</React.Suspense>
|
||||
</Provider>
|
||||
|
@ -19,7 +19,8 @@ type LoggerNamespace =
|
||||
| 'nodes'
|
||||
| 'system'
|
||||
| 'socketio'
|
||||
| 'session';
|
||||
| 'session'
|
||||
| 'dnd';
|
||||
|
||||
export const logger = (namespace: LoggerNamespace) =>
|
||||
$logger.get().child({ namespace });
|
||||
|
@ -15,7 +15,7 @@ export const actionsDenylist = [
|
||||
'socket/socketGeneratorProgress',
|
||||
'socket/appSocketGeneratorProgress',
|
||||
// every time user presses shift
|
||||
'hotkeys/shiftKeyPressed',
|
||||
// 'hotkeys/shiftKeyPressed',
|
||||
// this happens after every state change
|
||||
'@@REMEMBER_PERSISTED',
|
||||
];
|
||||
|
@ -15,6 +15,7 @@ import { addDeleteBoardAndImagesFulfilledListener } from './listeners/boardAndIm
import { addBoardIdSelectedListener } from './listeners/boardIdSelected';
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
import { addCanvasMaskSavedToGalleryListener } from './listeners/canvasMaskSavedToGallery';
import { addCanvasMergedListener } from './listeners/canvasMerged';
import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery';
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
@ -27,8 +28,8 @@ import {
addImageDeletedFulfilledListener,
addImageDeletedPendingListener,
addImageDeletedRejectedListener,
addRequestedSingleImageDeletionListener,
addRequestedMultipleImageDeletionListener,
addRequestedSingleImageDeletionListener,
} from './listeners/imageDeleted';
import { addImageDroppedListener } from './listeners/imageDropped';
import {
@ -79,6 +80,8 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
import { addImagesStarredListener } from './listeners/imagesStarred';
import { addImagesUnstarredListener } from './listeners/imagesUnstarred';

export const listenerMiddleware = createListenerMiddleware();

@ -120,6 +123,10 @@ addImageDeletedRejectedListener();
addDeleteBoardAndImagesFulfilledListener();
addImageToDeleteSelectedListener();

// Image starred
addImagesStarredListener();
addImagesUnstarredListener();

// User Invoked
addUserInvokedCanvasListener();
addUserInvokedNodesListener();
@ -129,6 +136,7 @@ addSessionReadyToInvokeListener();

// Canvas actions
addCanvasSavedToGalleryListener();
addCanvasMaskSavedToGalleryListener();
addCanvasDownloadedAsImageListener();
addCanvasCopiedToClipboardListener();
addCanvasMergedListener();

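The add*Listener calls above all register through startAppListening. A sketch of the usual Redux Toolkit pattern behind that export, assuming this repo follows the documented TypedStartListening recipe (the exact file layout and the AppDispatch type are assumptions):

import { createListenerMiddleware } from '@reduxjs/toolkit';
import type { TypedStartListening } from '@reduxjs/toolkit';
import type { AppDispatch, RootState } from 'app/store/store';

export const listenerMiddleware = createListenerMiddleware();

export type AppStartListening = TypedStartListening<RootState, AppDispatch>;

// Each listeners/*.ts module imports this and registers itself with a
// matcher or actionCreator plus an effect, as the files in this diff do.
export const startAppListening =
  listenerMiddleware.startListening as AppStartListening;
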
@ -0,0 +1,60 @@
import { logger } from 'app/logging/logger';
import { canvasMaskSavedToGallery } from 'features/canvas/store/actions';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { addToast } from 'features/system/store/systemSlice';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..';

export const addCanvasMaskSavedToGalleryListener = () => {
startAppListening({
actionCreator: canvasMaskSavedToGallery,
effect: async (action, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();

const canvasBlobsAndImageData = await getCanvasData(
state.canvas.layerState,
state.canvas.boundingBoxCoordinates,
state.canvas.boundingBoxDimensions,
state.canvas.isMaskEnabled,
state.canvas.shouldPreserveMaskedArea
);

if (!canvasBlobsAndImageData) {
return;
}

const { maskBlob } = canvasBlobsAndImageData;

if (!maskBlob) {
log.error('Problem getting mask layer blob');
dispatch(
addToast({
title: 'Problem Saving Mask',
description: 'Unable to export mask',
status: 'error',
})
);
return;
}

const { autoAddBoardId } = state.gallery;

dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([maskBlob], 'canvasMaskImage.png', {
type: 'image/png',
}),
image_category: 'mask',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: true,
postUploadAction: {
type: 'TOAST',
toastOptions: { title: 'Mask Saved to Assets' },
},
})
);
},
});
};

@ -1,16 +1,20 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import {
|
||||
TypesafeDraggableData,
|
||||
TypesafeDroppableData,
|
||||
} from 'app/components/ImageDnd/typesafeDnd';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
||||
import {
|
||||
TypesafeDraggableData,
|
||||
TypesafeDroppableData,
|
||||
} from 'features/dnd/types';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
fieldImageValueChanged,
|
||||
workflowExposedFieldAdded,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { startAppListening } from '../';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
|
||||
export const dndDropped = createAction<{
|
||||
overData: TypesafeDroppableData;
|
||||
@ -21,7 +25,7 @@ export const addImageDroppedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: dndDropped,
|
||||
effect: async (action, { dispatch }) => {
|
||||
const log = logger('images');
|
||||
const log = logger('dnd');
|
||||
const { activeData, overData } = action.payload;
|
||||
|
||||
if (activeData.payloadType === 'IMAGE_DTO') {
|
||||
@ -31,10 +35,28 @@ export const addImageDroppedListener = () => {
|
||||
{ activeData, overData },
|
||||
`Images (${activeData.payload.imageDTOs.length}) dropped`
|
||||
);
|
||||
} else if (activeData.payloadType === 'NODE_FIELD') {
|
||||
log.debug(
|
||||
{ activeData: parseify(activeData), overData: parseify(overData) },
|
||||
'Node field dropped'
|
||||
);
|
||||
} else {
|
||||
log.debug({ activeData, overData }, `Unknown payload dropped`);
|
||||
}
|
||||
|
||||
if (
|
||||
overData.actionType === 'ADD_FIELD_TO_LINEAR' &&
|
||||
activeData.payloadType === 'NODE_FIELD'
|
||||
) {
|
||||
const { nodeId, field } = activeData.payload;
|
||||
dispatch(
|
||||
workflowExposedFieldAdded({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Image dropped on current image
|
||||
*/
|
||||
@ -99,7 +121,7 @@ export const addImageDroppedListener = () => {
|
||||
) {
|
||||
const { fieldName, nodeId } = overData.context;
|
||||
dispatch(
|
||||
fieldValueChanged({
|
||||
fieldImageValueChanged({
|
||||
nodeId,
|
||||
fieldName,
|
||||
value: activeData.payload.imageDTO,
|
||||
|
@ -2,7 +2,7 @@ import { UseToastOptions } from '@chakra-ui/react';
import { logger } from 'app/logging/logger';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { addToast } from 'features/system/store/systemSlice';
import { omit } from 'lodash-es';
@ -111,7 +111,9 @@ export const addImageUploadedFulfilledListener = () => {

if (postUploadAction?.type === 'SET_NODES_IMAGE') {
const { nodeId, fieldName } = postUploadAction;
dispatch(fieldValueChanged({ nodeId, fieldName, value: imageDTO }));
dispatch(
fieldImageValueChanged({ nodeId, fieldName, value: imageDTO })
);
dispatch(
addToast({
...DEFAULT_UPLOADED_TOAST,

@ -0,0 +1,30 @@
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..';
import { selectionChanged } from '../../../../../features/gallery/store/gallerySlice';
import { ImageDTO } from '../../../../../services/api/types';

export const addImagesStarredListener = () => {
startAppListening({
matcher: imagesApi.endpoints.starImages.matchFulfilled,
effect: async (action, { dispatch, getState }) => {
const { updated_image_names: starredImages } = action.payload;

const state = getState();

const { selection } = state.gallery;
const updatedSelection: ImageDTO[] = [];

selection.forEach((selectedImageDTO) => {
if (starredImages.includes(selectedImageDTO.image_name)) {
updatedSelection.push({
...selectedImageDTO,
starred: true,
});
} else {
updatedSelection.push(selectedImageDTO);
}
});
dispatch(selectionChanged(updatedSelection));
},
});
};

@ -0,0 +1,30 @@
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..';
import { selectionChanged } from '../../../../../features/gallery/store/gallerySlice';
import { ImageDTO } from '../../../../../services/api/types';

export const addImagesUnstarredListener = () => {
startAppListening({
matcher: imagesApi.endpoints.unstarImages.matchFulfilled,
effect: async (action, { dispatch, getState }) => {
const { updated_image_names: unstarredImages } = action.payload;

const state = getState();

const { selection } = state.gallery;
const updatedSelection: ImageDTO[] = [];

selection.forEach((selectedImageDTO) => {
if (unstarredImages.includes(selectedImageDTO.image_name)) {
updatedSelection.push({
...selectedImageDTO,
starred: false,
});
} else {
updatedSelection.push(selectedImageDTO);
}
});
dispatch(selectionChanged(updatedSelection));
},
});
};

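These two listeners react to the starImages / unstarImages endpoints succeeding. A minimal sketch of how a gallery component might trigger them, assuming the endpoints are standard RTK Query mutations that take { image_names: string[] } and that the auto-generated hook names below exist; both are assumptions, not part of this diff:

import { useCallback } from 'react';
import {
  useStarImagesMutation,
  useUnstarImagesMutation,
} from 'services/api/endpoints/images'; // hook names are assumed
import { ImageDTO } from 'services/api/types';

export const useToggleStarred = (selection: ImageDTO[]) => {
  const [starImages] = useStarImagesMutation();
  const [unstarImages] = useUnstarImagesMutation();

  return useCallback(() => {
    const image_names = selection.map((i) => i.image_name);
    // if every selected image is already starred, unstar them; otherwise star
    if (selection.every((i) => i.starred)) {
      unstarImages({ image_names });
    } else {
      starImages({ image_names });
    }
  }, [selection, starImages, unstarImages]);
};
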
@ -15,12 +15,21 @@ import {
|
||||
setShouldUseSDXLRefiner,
|
||||
} from 'features/sdxl/store/sdxlSlice';
|
||||
import { forEach, some } from 'lodash-es';
|
||||
import { modelsApi, vaeModelsAdapter } from 'services/api/endpoints/models';
|
||||
import {
|
||||
mainModelsAdapter,
|
||||
modelsApi,
|
||||
vaeModelsAdapter,
|
||||
} from 'services/api/endpoints/models';
|
||||
import { TypeGuardFor } from 'services/api/types';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
export const addModelsLoadedListener = () => {
|
||||
startAppListening({
|
||||
predicate: (state, action) =>
|
||||
predicate: (
|
||||
action
|
||||
): action is TypeGuardFor<
|
||||
typeof modelsApi.endpoints.getMainModels.matchFulfilled
|
||||
> =>
|
||||
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
|
||||
!action.meta.arg.originalArgs.includes('sdxl-refiner'),
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
@ -32,29 +41,28 @@ export const addModelsLoadedListener = () => {
|
||||
);
|
||||
|
||||
const currentModel = getState().generation.model;
|
||||
const models = mainModelsAdapter.getSelectors().selectAll(action.payload);
|
||||
|
||||
const isCurrentModelAvailable = some(
|
||||
action.payload.entities,
|
||||
(m) =>
|
||||
m?.model_name === currentModel?.model_name &&
|
||||
m?.base_model === currentModel?.base_model &&
|
||||
m?.model_type === currentModel?.model_type
|
||||
);
|
||||
|
||||
if (isCurrentModelAvailable) {
|
||||
return;
|
||||
}
|
||||
|
||||
const firstModelId = action.payload.ids[0];
|
||||
const firstModel = action.payload.entities[firstModelId];
|
||||
|
||||
if (!firstModel) {
|
||||
if (models.length === 0) {
|
||||
// No models loaded at all
|
||||
dispatch(modelChanged(null));
|
||||
return;
|
||||
}
|
||||
|
||||
const result = zMainOrOnnxModel.safeParse(firstModel);
|
||||
const isCurrentModelAvailable = currentModel
|
||||
? models.some(
|
||||
(m) =>
|
||||
m.model_name === currentModel.model_name &&
|
||||
m.base_model === currentModel.base_model &&
|
||||
m.model_type === currentModel.model_type
|
||||
)
|
||||
: false;
|
||||
|
||||
if (isCurrentModelAvailable) {
|
||||
return;
|
||||
}
|
||||
|
||||
const result = zMainOrOnnxModel.safeParse(models[0]);
|
||||
|
||||
if (!result.success) {
|
||||
log.error(
|
||||
@ -68,7 +76,11 @@ export const addModelsLoadedListener = () => {
|
||||
},
|
||||
});
|
||||
startAppListening({
|
||||
predicate: (state, action) =>
|
||||
predicate: (
|
||||
action
|
||||
): action is TypeGuardFor<
|
||||
typeof modelsApi.endpoints.getMainModels.matchFulfilled
|
||||
> =>
|
||||
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
|
||||
action.meta.arg.originalArgs.includes('sdxl-refiner'),
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
@ -80,30 +92,29 @@ export const addModelsLoadedListener = () => {
|
||||
);
|
||||
|
||||
const currentModel = getState().sdxl.refinerModel;
|
||||
const models = mainModelsAdapter.getSelectors().selectAll(action.payload);
|
||||
|
||||
const isCurrentModelAvailable = some(
|
||||
action.payload.entities,
|
||||
(m) =>
|
||||
m?.model_name === currentModel?.model_name &&
|
||||
m?.base_model === currentModel?.base_model &&
|
||||
m?.model_type === currentModel?.model_type
|
||||
);
|
||||
|
||||
if (isCurrentModelAvailable) {
|
||||
return;
|
||||
}
|
||||
|
||||
const firstModelId = action.payload.ids[0];
|
||||
const firstModel = action.payload.entities[firstModelId];
|
||||
|
||||
if (!firstModel) {
|
||||
if (models.length === 0) {
|
||||
// No models loaded at all
|
||||
dispatch(refinerModelChanged(null));
|
||||
dispatch(setShouldUseSDXLRefiner(false));
|
||||
return;
|
||||
}
|
||||
|
||||
const result = zSDXLRefinerModel.safeParse(firstModel);
|
||||
const isCurrentModelAvailable = currentModel
|
||||
? models.some(
|
||||
(m) =>
|
||||
m.model_name === currentModel.model_name &&
|
||||
m.base_model === currentModel.base_model &&
|
||||
m.model_type === currentModel.model_type
|
||||
)
|
||||
: false;
|
||||
|
||||
if (isCurrentModelAvailable) {
|
||||
return;
|
||||
}
|
||||
|
||||
const result = zSDXLRefinerModel.safeParse(models[0]);
|
||||
|
||||
if (!result.success) {
|
||||
log.error(
|
||||
|
@ -13,7 +13,7 @@ export const addReceivedOpenAPISchemaListener = () => {
const log = logger('system');
const schemaJSON = action.payload;

log.debug({ schemaJSON }, 'Dereferenced OpenAPI schema');
log.debug({ schemaJSON }, 'Received OpenAPI schema');

const nodeTemplates = parseSchema(schemaJSON);

@ -28,9 +28,12 @@ export const addReceivedOpenAPISchemaListener = () => {

startAppListening({
actionCreator: receivedOpenAPISchema.rejected,
effect: () => {
effect: (action) => {
const log = logger('system');
log.error('Problem dereferencing OpenAPI Schema');
log.error(
{ error: parseify(action.error) },
'Problem retrieving OpenAPI Schema'
);
},
});
};

@ -19,7 +19,7 @@ import {
} from 'services/events/actions';
import { startAppListening } from '../..';

const nodeDenylist = ['dataURL_image'];
const nodeDenylist = ['load_image'];

export const addInvocationCompleteEventListener = () => {
startAppListening({

@ -15,7 +15,7 @@ export const addUserInvokedNodesListener = () => {
const log = logger('session');
const state = getState();

const graph = buildNodesGraph(state);
const graph = buildNodesGraph(state.nodes);
dispatch(nodesGraphBuilt(graph));
log.debug({ graph: parseify(graph) }, 'Nodes graph built');

@ -1,86 +1,7 @@
|
||||
import {
|
||||
// CONTROLNET_MODELS,
|
||||
CONTROLNET_PROCESSORS,
|
||||
} from 'features/controlNet/store/constants';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import { O } from 'ts-toolbelt';
|
||||
|
||||
// These are old types from the model management UI
|
||||
|
||||
// export type ModelStatus = 'active' | 'cached' | 'not loaded';
|
||||
|
||||
// export type Model = {
|
||||
// status: ModelStatus;
|
||||
// description: string;
|
||||
// weights: string;
|
||||
// config?: string;
|
||||
// vae?: string;
|
||||
// width?: number;
|
||||
// height?: number;
|
||||
// default?: boolean;
|
||||
// format?: string;
|
||||
// };
|
||||
|
||||
// export type DiffusersModel = {
|
||||
// status: ModelStatus;
|
||||
// description: string;
|
||||
// repo_id?: string;
|
||||
// path?: string;
|
||||
// vae?: {
|
||||
// repo_id?: string;
|
||||
// path?: string;
|
||||
// };
|
||||
// format?: string;
|
||||
// default?: boolean;
|
||||
// };
|
||||
|
||||
// export type ModelList = Record<string, Model & DiffusersModel>;
|
||||
|
||||
// export type FoundModel = {
|
||||
// name: string;
|
||||
// location: string;
|
||||
// };
|
||||
|
||||
// export type InvokeModelConfigProps = {
|
||||
// name: string | undefined;
|
||||
// description: string | undefined;
|
||||
// config: string | undefined;
|
||||
// weights: string | undefined;
|
||||
// vae: string | undefined;
|
||||
// width: number | undefined;
|
||||
// height: number | undefined;
|
||||
// default: boolean | undefined;
|
||||
// format: string | undefined;
|
||||
// };
|
||||
|
||||
// export type InvokeDiffusersModelConfigProps = {
|
||||
// name: string | undefined;
|
||||
// description: string | undefined;
|
||||
// repo_id: string | undefined;
|
||||
// path: string | undefined;
|
||||
// default: boolean | undefined;
|
||||
// format: string | undefined;
|
||||
// vae: {
|
||||
// repo_id: string | undefined;
|
||||
// path: string | undefined;
|
||||
// };
|
||||
// };
|
||||
|
||||
// export type InvokeModelConversionProps = {
|
||||
// model_name: string;
|
||||
// save_location: string;
|
||||
// custom_location: string | null;
|
||||
// };
|
||||
|
||||
// export type InvokeModelMergingProps = {
|
||||
// models_to_merge: string[];
|
||||
// alpha: number;
|
||||
// interp: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference';
|
||||
// force: boolean;
|
||||
// merged_model_name: string;
|
||||
// model_merge_save_path: string | null;
|
||||
// };
|
||||
|
||||
/**
|
||||
* A disable-able application feature
|
||||
*/
|
||||
|
@ -1,16 +1,11 @@
|
||||
import {
|
||||
ChakraProps,
|
||||
Flex,
|
||||
FlexProps,
|
||||
Icon,
|
||||
Image,
|
||||
useColorMode,
|
||||
useColorModeValue,
|
||||
} from '@chakra-ui/react';
|
||||
import {
|
||||
TypesafeDraggableData,
|
||||
TypesafeDroppableData,
|
||||
} from 'app/components/ImageDnd/typesafeDnd';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import {
|
||||
IAILoadingImageFallback,
|
||||
IAINoContentFallback,
|
||||
@ -26,22 +21,22 @@ import {
|
||||
useCallback,
|
||||
useState,
|
||||
} from 'react';
|
||||
import { FaImage, FaUndo, FaUpload } from 'react-icons/fa';
|
||||
import { FaImage, FaUpload } from 'react-icons/fa';
|
||||
import { ImageDTO, PostUploadAction } from 'services/api/types';
|
||||
import { mode } from 'theme/util/mode';
|
||||
import IAIDraggable from './IAIDraggable';
|
||||
import IAIDroppable from './IAIDroppable';
|
||||
import SelectionOverlay from './SelectionOverlay';
|
||||
import {
|
||||
TypesafeDraggableData,
|
||||
TypesafeDroppableData,
|
||||
} from 'features/dnd/types';
|
||||
|
||||
type IAIDndImageProps = {
|
||||
type IAIDndImageProps = FlexProps & {
|
||||
imageDTO: ImageDTO | undefined;
|
||||
onError?: (event: SyntheticEvent<HTMLImageElement>) => void;
|
||||
onLoad?: (event: SyntheticEvent<HTMLImageElement>) => void;
|
||||
onClick?: (event: MouseEvent<HTMLDivElement>) => void;
|
||||
onClickReset?: (event: MouseEvent<HTMLButtonElement>) => void;
|
||||
withResetIcon?: boolean;
|
||||
resetIcon?: ReactElement;
|
||||
resetTooltip?: string;
|
||||
withMetadataOverlay?: boolean;
|
||||
isDragDisabled?: boolean;
|
||||
isDropDisabled?: boolean;
|
||||
@ -58,15 +53,14 @@ type IAIDndImageProps = {
|
||||
noContentFallback?: ReactElement;
|
||||
useThumbailFallback?: boolean;
|
||||
withHoverOverlay?: boolean;
|
||||
children?: JSX.Element;
|
||||
};
|
||||
|
||||
const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
const {
|
||||
imageDTO,
|
||||
onClickReset,
|
||||
onError,
|
||||
onClick,
|
||||
withResetIcon = false,
|
||||
withMetadataOverlay = false,
|
||||
isDropDisabled = false,
|
||||
isDragDisabled = false,
|
||||
@ -80,32 +74,36 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
dropLabel,
|
||||
isSelected = false,
|
||||
thumbnail = false,
|
||||
resetTooltip = 'Reset',
|
||||
resetIcon = <FaUndo />,
|
||||
noContentFallback = <IAINoContentFallback icon={FaImage} />,
|
||||
useThumbailFallback,
|
||||
withHoverOverlay = false,
|
||||
children,
|
||||
onMouseOver,
|
||||
onMouseOut,
|
||||
} = props;
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
const handleMouseOver = useCallback(() => {
|
||||
setIsHovered(true);
|
||||
}, []);
|
||||
const handleMouseOut = useCallback(() => {
|
||||
setIsHovered(false);
|
||||
}, []);
|
||||
const handleMouseOver = useCallback(
|
||||
(e: MouseEvent<HTMLDivElement>) => {
|
||||
if (onMouseOver) onMouseOver(e);
|
||||
setIsHovered(true);
|
||||
},
|
||||
[onMouseOver]
|
||||
);
|
||||
const handleMouseOut = useCallback(
|
||||
(e: MouseEvent<HTMLDivElement>) => {
|
||||
if (onMouseOut) onMouseOut(e);
|
||||
setIsHovered(false);
|
||||
},
|
||||
[onMouseOut]
|
||||
);
|
||||
|
||||
const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({
|
||||
postUploadAction,
|
||||
isDisabled: isUploadDisabled,
|
||||
});
|
||||
|
||||
const resetIconShadow = useColorModeValue(
|
||||
`drop-shadow(0px 0px 0.1rem var(--invokeai-colors-base-600))`,
|
||||
`drop-shadow(0px 0px 0.1rem var(--invokeai-colors-base-800))`
|
||||
);
|
||||
|
||||
const uploadButtonStyles = isUploadDisabled
|
||||
? {}
|
||||
: {
|
||||
@ -157,11 +155,10 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
<IAILoadingImageFallback image={imageDTO} />
|
||||
)
|
||||
}
|
||||
width={imageDTO.width}
|
||||
height={imageDTO.height}
|
||||
onError={onError}
|
||||
draggable={false}
|
||||
sx={{
|
||||
w: imageDTO.width,
|
||||
objectFit: 'contain',
|
||||
maxW: 'full',
|
||||
maxH: 'full',
|
||||
@ -220,30 +217,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
dropLabel={dropLabel}
|
||||
/>
|
||||
)}
|
||||
{onClickReset && withResetIcon && imageDTO && (
|
||||
<IAIIconButton
|
||||
onClick={onClickReset}
|
||||
aria-label={resetTooltip}
|
||||
tooltip={resetTooltip}
|
||||
icon={resetIcon}
|
||||
size="sm"
|
||||
variant="link"
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: 1,
|
||||
insetInlineEnd: 1,
|
||||
p: 0,
|
||||
minW: 0,
|
||||
svg: {
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: 'normal',
|
||||
fill: 'base.100',
|
||||
_hover: { fill: 'base.50' },
|
||||
filter: resetIconShadow,
|
||||
},
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
{children}
|
||||
</Flex>
|
||||
)}
|
||||
</ImageContextMenu>
|
||||
|
@ -0,0 +1,46 @@
|
||||
import { SystemStyleObject, useColorModeValue } from '@chakra-ui/react';
|
||||
import { MouseEvent, ReactElement, memo } from 'react';
|
||||
import IAIIconButton from './IAIIconButton';
|
||||
|
||||
type Props = {
|
||||
onClick: (event: MouseEvent<HTMLButtonElement>) => void;
|
||||
tooltip: string;
|
||||
icon?: ReactElement;
|
||||
styleOverrides?: SystemStyleObject;
|
||||
};
|
||||
|
||||
const IAIDndImageIcon = (props: Props) => {
|
||||
const { onClick, tooltip, icon, styleOverrides } = props;
|
||||
|
||||
const resetIconShadow = useColorModeValue(
|
||||
`drop-shadow(0px 0px 0.1rem var(--invokeai-colors-base-600))`,
|
||||
`drop-shadow(0px 0px 0.1rem var(--invokeai-colors-base-800))`
|
||||
);
|
||||
return (
|
||||
<IAIIconButton
|
||||
onClick={onClick}
|
||||
aria-label={tooltip}
|
||||
tooltip={tooltip}
|
||||
icon={icon}
|
||||
size="sm"
|
||||
variant="link"
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: 1,
|
||||
insetInlineEnd: 1,
|
||||
p: 0,
|
||||
minW: 0,
|
||||
svg: {
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: 'normal',
|
||||
fill: 'base.100',
|
||||
_hover: { fill: 'base.50' },
|
||||
filter: resetIconShadow,
|
||||
},
|
||||
...styleOverrides,
|
||||
}}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(IAIDndImageIcon);
|
@ -1,22 +1,19 @@
|
||||
import { Box } from '@chakra-ui/react';
|
||||
import {
|
||||
TypesafeDraggableData,
|
||||
useDraggable,
|
||||
} from 'app/components/ImageDnd/typesafeDnd';
|
||||
import { MouseEvent, memo, useRef } from 'react';
|
||||
import { Box, BoxProps } from '@chakra-ui/react';
|
||||
import { useDraggableTypesafe } from 'features/dnd/hooks/typesafeHooks';
|
||||
import { TypesafeDraggableData } from 'features/dnd/types';
|
||||
import { memo, useRef } from 'react';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
type IAIDraggableProps = {
|
||||
type IAIDraggableProps = BoxProps & {
|
||||
disabled?: boolean;
|
||||
data?: TypesafeDraggableData;
|
||||
onClick?: (event: MouseEvent<HTMLDivElement>) => void;
|
||||
};
|
||||
|
||||
const IAIDraggable = (props: IAIDraggableProps) => {
|
||||
const { data, disabled, onClick } = props;
|
||||
const { data, disabled, ...rest } = props;
|
||||
const dndId = useRef(uuidv4());
|
||||
|
||||
const { attributes, listeners, setNodeRef } = useDraggable({
|
||||
const { attributes, listeners, setNodeRef } = useDraggableTypesafe({
|
||||
id: dndId.current,
|
||||
disabled,
|
||||
data,
|
||||
@ -24,7 +21,6 @@ const IAIDraggable = (props: IAIDraggableProps) => {
|
||||
|
||||
return (
|
||||
<Box
|
||||
onClick={onClick}
|
||||
ref={setNodeRef}
|
||||
position="absolute"
|
||||
w="full"
|
||||
@ -33,6 +29,7 @@ const IAIDraggable = (props: IAIDraggableProps) => {
|
||||
insetInlineStart={0}
|
||||
{...attributes}
|
||||
{...listeners}
|
||||
{...rest}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
@ -1,9 +1,7 @@
import { Box } from '@chakra-ui/react';
import {
TypesafeDroppableData,
isValidDrop,
useDroppable,
} from 'app/components/ImageDnd/typesafeDnd';
import { useDroppableTypesafe } from 'features/dnd/hooks/typesafeHooks';
import { TypesafeDroppableData } from 'features/dnd/types';
import { isValidDrop } from 'features/dnd/util/isValidDrop';
import { AnimatePresence } from 'framer-motion';
import { ReactNode, memo, useRef } from 'react';
import { v4 as uuidv4 } from 'uuid';
@ -19,7 +17,7 @@ const IAIDroppable = (props: IAIDroppableProps) => {
const { dropLabel, data, disabled } = props;
const dndId = useRef(uuidv4());

const { isOver, setNodeRef, active } = useDroppable({
const { isOver, setNodeRef, active } = useDroppableTypesafe({
id: dndId.current,
disabled,
data,

@ -49,7 +49,7 @@ export const IAILoadingImageFallback = (props: Props) => {

type IAINoImageFallbackProps = {
label?: string;
icon?: As;
icon?: As | null;
boxSize?: StyleProps['boxSize'];
sx?: ChakraProps['sx'];
};
@ -76,7 +76,7 @@ export const IAINoContentFallback = (props: IAINoImageFallbackProps) => {
...props.sx,
}}
>
<Icon as={icon} boxSize={boxSize} opacity={0.7} />
{icon && <Icon as={icon} boxSize={boxSize} opacity={0.7} />}
{props.label && <Text textAlign="center">{props.label}</Text>}
</Flex>
);

@ -1,10 +1,13 @@
|
||||
import {
|
||||
Flex,
|
||||
FormControl,
|
||||
FormControlProps,
|
||||
FormHelperText,
|
||||
FormLabel,
|
||||
FormLabelProps,
|
||||
Switch,
|
||||
SwitchProps,
|
||||
Text,
|
||||
Tooltip,
|
||||
} from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
@ -15,6 +18,7 @@ export interface IAISwitchProps extends SwitchProps {
|
||||
formControlProps?: FormControlProps;
|
||||
formLabelProps?: FormLabelProps;
|
||||
tooltip?: string;
|
||||
helperText?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -28,6 +32,7 @@ const IAISwitch = (props: IAISwitchProps) => {
|
||||
formControlProps,
|
||||
formLabelProps,
|
||||
tooltip,
|
||||
helperText,
|
||||
...rest
|
||||
} = props;
|
||||
return (
|
||||
@ -35,25 +40,33 @@ const IAISwitch = (props: IAISwitchProps) => {
|
||||
<FormControl
|
||||
isDisabled={isDisabled}
|
||||
width={width}
|
||||
display="flex"
|
||||
alignItems="center"
|
||||
{...formControlProps}
|
||||
>
|
||||
{label && (
|
||||
<FormLabel
|
||||
my={1}
|
||||
flexGrow={1}
|
||||
sx={{
|
||||
cursor: isDisabled ? 'not-allowed' : 'pointer',
|
||||
...formLabelProps?.sx,
|
||||
pe: 4,
|
||||
}}
|
||||
{...formLabelProps}
|
||||
>
|
||||
{label}
|
||||
</FormLabel>
|
||||
)}
|
||||
<Switch {...rest} />
|
||||
<Flex sx={{ flexDir: 'column', w: 'full' }}>
|
||||
<Flex sx={{ alignItems: 'center', w: 'full' }}>
|
||||
{label && (
|
||||
<FormLabel
|
||||
my={1}
|
||||
flexGrow={1}
|
||||
sx={{
|
||||
cursor: isDisabled ? 'not-allowed' : 'pointer',
|
||||
...formLabelProps?.sx,
|
||||
pe: 4,
|
||||
}}
|
||||
{...formLabelProps}
|
||||
>
|
||||
{label}
|
||||
</FormLabel>
|
||||
)}
|
||||
<Switch {...rest} />
|
||||
</Flex>
|
||||
{helperText && (
|
||||
<FormHelperText>
|
||||
<Text variant="subtext">{helperText}</Text>
|
||||
</FormHelperText>
|
||||
)}
|
||||
</Flex>
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
);
|
||||
|
@ -40,6 +40,44 @@ export const useChakraThemeTokens = () => {
|
||||
accent850,
|
||||
accent900,
|
||||
accent950,
|
||||
baseAlpha50,
|
||||
baseAlpha100,
|
||||
baseAlpha150,
|
||||
baseAlpha200,
|
||||
baseAlpha250,
|
||||
baseAlpha300,
|
||||
baseAlpha350,
|
||||
baseAlpha400,
|
||||
baseAlpha450,
|
||||
baseAlpha500,
|
||||
baseAlpha550,
|
||||
baseAlpha600,
|
||||
baseAlpha650,
|
||||
baseAlpha700,
|
||||
baseAlpha750,
|
||||
baseAlpha800,
|
||||
baseAlpha850,
|
||||
baseAlpha900,
|
||||
baseAlpha950,
|
||||
accentAlpha50,
|
||||
accentAlpha100,
|
||||
accentAlpha150,
|
||||
accentAlpha200,
|
||||
accentAlpha250,
|
||||
accentAlpha300,
|
||||
accentAlpha350,
|
||||
accentAlpha400,
|
||||
accentAlpha450,
|
||||
accentAlpha500,
|
||||
accentAlpha550,
|
||||
accentAlpha600,
|
||||
accentAlpha650,
|
||||
accentAlpha700,
|
||||
accentAlpha750,
|
||||
accentAlpha800,
|
||||
accentAlpha850,
|
||||
accentAlpha900,
|
||||
accentAlpha950,
|
||||
] = useToken('colors', [
|
||||
'base.50',
|
||||
'base.100',
|
||||
@ -79,6 +117,44 @@ export const useChakraThemeTokens = () => {
|
||||
'accent.850',
|
||||
'accent.900',
|
||||
'accent.950',
|
||||
'baseAlpha.50',
|
||||
'baseAlpha.100',
|
||||
'baseAlpha.150',
|
||||
'baseAlpha.200',
|
||||
'baseAlpha.250',
|
||||
'baseAlpha.300',
|
||||
'baseAlpha.350',
|
||||
'baseAlpha.400',
|
||||
'baseAlpha.450',
|
||||
'baseAlpha.500',
|
||||
'baseAlpha.550',
|
||||
'baseAlpha.600',
|
||||
'baseAlpha.650',
|
||||
'baseAlpha.700',
|
||||
'baseAlpha.750',
|
||||
'baseAlpha.800',
|
||||
'baseAlpha.850',
|
||||
'baseAlpha.900',
|
||||
'baseAlpha.950',
|
||||
'accentAlpha.50',
|
||||
'accentAlpha.100',
|
||||
'accentAlpha.150',
|
||||
'accentAlpha.200',
|
||||
'accentAlpha.250',
|
||||
'accentAlpha.300',
|
||||
'accentAlpha.350',
|
||||
'accentAlpha.400',
|
||||
'accentAlpha.450',
|
||||
'accentAlpha.500',
|
||||
'accentAlpha.550',
|
||||
'accentAlpha.600',
|
||||
'accentAlpha.650',
|
||||
'accentAlpha.700',
|
||||
'accentAlpha.750',
|
||||
'accentAlpha.800',
|
||||
'accentAlpha.850',
|
||||
'accentAlpha.900',
|
||||
'accentAlpha.950',
|
||||
]);
|
||||
|
||||
return {
|
||||
@ -120,5 +196,43 @@ export const useChakraThemeTokens = () => {
|
||||
accent850,
|
||||
accent900,
|
||||
accent950,
|
||||
baseAlpha50,
|
||||
baseAlpha100,
|
||||
baseAlpha150,
|
||||
baseAlpha200,
|
||||
baseAlpha250,
|
||||
baseAlpha300,
|
||||
baseAlpha350,
|
||||
baseAlpha400,
|
||||
baseAlpha450,
|
||||
baseAlpha500,
|
||||
baseAlpha550,
|
||||
baseAlpha600,
|
||||
baseAlpha650,
|
||||
baseAlpha700,
|
||||
baseAlpha750,
|
||||
baseAlpha800,
|
||||
baseAlpha850,
|
||||
baseAlpha900,
|
||||
baseAlpha950,
|
||||
accentAlpha50,
|
||||
accentAlpha100,
|
||||
accentAlpha150,
|
||||
accentAlpha200,
|
||||
accentAlpha250,
|
||||
accentAlpha300,
|
||||
accentAlpha350,
|
||||
accentAlpha400,
|
||||
accentAlpha450,
|
||||
accentAlpha500,
|
||||
accentAlpha550,
|
||||
accentAlpha600,
|
||||
accentAlpha650,
|
||||
accentAlpha700,
|
||||
accentAlpha750,
|
||||
accentAlpha800,
|
||||
accentAlpha850,
|
||||
accentAlpha900,
|
||||
accentAlpha950,
|
||||
};
|
||||
};
|
||||
|
@ -1,4 +1,10 @@
/**
* Serialize an object to JSON and back to a new object
*/
export const parseify = (obj: unknown) => JSON.parse(JSON.stringify(obj));
export const parseify = (obj: unknown) => {
try {
return JSON.parse(JSON.stringify(obj));
} catch {
return 'Error parsing object';
}
};

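A quick sketch of why the try/catch matters: values JSON.stringify cannot handle (circular references, BigInt, etc.) now fall back to a placeholder string instead of throwing inside a log call. The sample object below is illustrative only:

import { parseify } from 'common/util/serialize';

const node: { id: string; self?: unknown } = { id: 'a' };
node.self = node; // circular reference

console.log(parseify({ id: 'a' })); // { id: 'a' }
console.log(parseify(node));        // 'Error parsing object'
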
@ -2,10 +2,11 @@ import { ButtonGroup, Flex } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
||||
import IAIColorPicker from 'common/components/IAIColorPicker';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import IAIPopover from 'common/components/IAIPopover';
|
||||
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
||||
import { canvasMaskSavedToGallery } from 'features/canvas/store/actions';
|
||||
import {
|
||||
canvasSelector,
|
||||
isStagingSelector,
|
||||
@ -22,7 +23,7 @@ import { isEqual } from 'lodash-es';
|
||||
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaMask, FaTrash } from 'react-icons/fa';
|
||||
import { FaMask, FaSave, FaTrash } from 'react-icons/fa';
|
||||
|
||||
export const selector = createSelector(
|
||||
[canvasSelector, isStagingSelector],
|
||||
@ -102,6 +103,10 @@ const IAICanvasMaskOptions = () => {
|
||||
const handleToggleEnableMask = () =>
|
||||
dispatch(setIsMaskEnabled(!isMaskEnabled));
|
||||
|
||||
const handleSaveMask = async () => {
|
||||
dispatch(canvasMaskSavedToGallery());
|
||||
};
|
||||
|
||||
return (
|
||||
<IAIPopover
|
||||
triggerComponent={
|
||||
@ -134,6 +139,9 @@ const IAICanvasMaskOptions = () => {
|
||||
pickerColor={maskColor}
|
||||
onChange={(newColor) => dispatch(setMaskColor(newColor))}
|
||||
/>
|
||||
<IAIButton size="sm" leftIcon={<FaSave />} onClick={handleSaveMask}>
|
||||
Save Mask
|
||||
</IAIButton>
|
||||
<IAIButton size="sm" leftIcon={<FaTrash />} onClick={handleClearMask}>
|
||||
{t('unifiedCanvas.clearMask')} (Shift+C)
|
||||
</IAIButton>
|
||||
|
@ -3,6 +3,10 @@ import { ImageDTO } from 'services/api/types';

export const canvasSavedToGallery = createAction('canvas/canvasSavedToGallery');

export const canvasMaskSavedToGallery = createAction(
'canvas/canvasMaskSavedToGallery'
);

export const canvasCopiedToClipboard = createAction(
'canvas/canvasCopiedToClipboard'
);

@ -4,14 +4,16 @@ import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
import {
|
||||
TypesafeDraggableData,
|
||||
TypesafeDroppableData,
|
||||
} from 'app/components/ImageDnd/typesafeDnd';
|
||||
} from 'features/dnd/types';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { FaUndo } from 'react-icons/fa';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { PostUploadAction } from 'services/api/types';
|
||||
import IAIDndImageIcon from '../../../common/components/IAIDndImageIcon';
|
||||
import {
|
||||
ControlNetConfig,
|
||||
controlNetImageChanged,
|
||||
@ -119,11 +121,15 @@ const ControlNetImagePreview = (props: Props) => {
|
||||
droppableData={droppableData}
|
||||
imageDTO={controlImage}
|
||||
isDropDisabled={shouldShowProcessedImage || !isEnabled}
|
||||
onClickReset={handleResetControlImage}
|
||||
postUploadAction={postUploadAction}
|
||||
resetTooltip="Reset Control Image"
|
||||
withResetIcon={Boolean(controlImage)}
|
||||
/>
|
||||
>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleResetControlImage}
|
||||
icon={controlImage ? <FaUndo /> : undefined}
|
||||
tooltip="Reset Control Image"
|
||||
/>
|
||||
</IAIDndImage>
|
||||
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
@ -143,10 +149,13 @@ const ControlNetImagePreview = (props: Props) => {
|
||||
imageDTO={processedControlImage}
|
||||
isUploadDisabled={true}
|
||||
isDropDisabled={!isEnabled}
|
||||
onClickReset={handleResetControlImage}
|
||||
resetTooltip="Reset Control Image"
|
||||
withResetIcon={Boolean(controlImage)}
|
||||
/>
|
||||
>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleResetControlImage}
|
||||
icon={controlImage ? <FaUndo /> : undefined}
|
||||
tooltip="Reset Control Image"
|
||||
/>
|
||||
</IAIDndImage>
|
||||
</Box>
|
||||
{pendingControlImages.includes(controlNetId) && (
|
||||
<Flex
|
||||
|
@ -138,7 +138,7 @@ export type RequiredZoeDepthImageProcessorInvocation = O.Required<
/**
* Any ControlNet Processor node, with its parameters flagged as required
*/
export type RequiredControlNetProcessorNode =
export type RequiredControlNetProcessorNode = O.Required<
| RequiredCannyImageProcessorInvocation
| RequiredContentShuffleImageProcessorInvocation
| RequiredHedImageProcessorInvocation
@ -150,7 +150,9 @@ export type RequiredControlNetProcessorNode =
| RequiredNormalbaeImageProcessorInvocation
| RequiredOpenposeImageProcessorInvocation
| RequiredPidiImageProcessorInvocation
| RequiredZoeDepthImageProcessorInvocation;
| RequiredZoeDepthImageProcessorInvocation,
'id'
>;

/**
* Type guard for CannyImageProcessorInvocation

@ -3,6 +3,7 @@ import { RootState } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { some } from 'lodash-es';
import { ImageUsage } from './types';
import { isInvocationNode } from 'features/nodes/types/types';

export const getImageUsage = (state: RootState, image_name: string) => {
const { generation, canvas, nodes, controlNet } = state;
@ -12,11 +13,11 @@ export const getImageUsage = (state: RootState, image_name: string) => {
(obj) => obj.kind === 'image' && obj.imageName === image_name
);

const isNodesImage = nodes.nodes.some((node) => {
const isNodesImage = nodes.nodes.filter(isInvocationNode).some((node) => {
return some(
node.data.inputs,
(input) =>
input.type === 'image' && input.value?.image_name === image_name
input.type === 'ImageField' && input.value?.image_name === image_name
);
});

@ -6,23 +6,18 @@ import {
|
||||
useSensor,
|
||||
useSensors,
|
||||
} from '@dnd-kit/core';
|
||||
import { snapCenterToCursor } from '@dnd-kit/modifiers';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { dndDropped } from 'app/store/middleware/listenerMiddleware/listeners/imageDropped';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { AnimatePresence, motion } from 'framer-motion';
|
||||
import { PropsWithChildren, memo, useCallback, useState } from 'react';
|
||||
import { useScaledModifer } from '../hooks/useScaledCenteredModifer';
|
||||
import { DragEndEvent, DragStartEvent, TypesafeDraggableData } from '../types';
|
||||
import { DndContextTypesafe } from './DndContextTypesafe';
|
||||
import DragPreview from './DragPreview';
|
||||
import {
|
||||
DndContext,
|
||||
DragEndEvent,
|
||||
DragStartEvent,
|
||||
TypesafeDraggableData,
|
||||
} from './typesafeDnd';
|
||||
import { logger } from 'app/logging/logger';
|
||||
|
||||
type ImageDndContextProps = PropsWithChildren;
|
||||
|
||||
const ImageDndContext = (props: ImageDndContextProps) => {
|
||||
const AppDndContext = (props: PropsWithChildren) => {
|
||||
const [activeDragData, setActiveDragData] =
|
||||
useState<TypesafeDraggableData | null>(null);
|
||||
const log = logger('images');
|
||||
@ -31,7 +26,10 @@ const ImageDndContext = (props: ImageDndContextProps) => {
|
||||
|
||||
const handleDragStart = useCallback(
|
||||
(event: DragStartEvent) => {
|
||||
log.trace({ dragData: event.active.data.current }, 'Drag started');
|
||||
log.trace(
|
||||
{ dragData: parseify(event.active.data.current) },
|
||||
'Drag started'
|
||||
);
|
||||
const activeData = event.active.data.current;
|
||||
if (!activeData) {
|
||||
return;
|
||||
@ -43,7 +41,10 @@ const ImageDndContext = (props: ImageDndContextProps) => {
|
||||
|
||||
const handleDragEnd = useCallback(
|
||||
(event: DragEndEvent) => {
|
||||
log.trace({ dragData: event.active.data.current }, 'Drag ended');
|
||||
log.trace(
|
||||
{ dragData: parseify(event.active.data.current) },
|
||||
'Drag ended'
|
||||
);
|
||||
const overData = event.over?.data.current;
|
||||
if (!activeDragData || !overData) {
|
||||
return;
|
||||
@ -69,15 +70,29 @@ const ImageDndContext = (props: ImageDndContextProps) => {
|
||||
|
||||
const sensors = useSensors(mouseSensor, touchSensor);
|
||||
|
||||
const scaledModifier = useScaledModifer();
|
||||
|
||||
return (
|
||||
<DndContext
|
||||
<DndContextTypesafe
|
||||
onDragStart={handleDragStart}
|
||||
onDragEnd={handleDragEnd}
|
||||
sensors={sensors}
|
||||
collisionDetection={pointerWithin}
|
||||
autoScroll={false}
|
||||
>
|
||||
{props.children}
|
||||
<DragOverlay dropAnimation={null} modifiers={[snapCenterToCursor]}>
|
||||
<DragOverlay
|
||||
dropAnimation={null}
|
||||
modifiers={[scaledModifier]}
|
||||
style={{
|
||||
width: 'min-content',
|
||||
height: 'min-content',
|
||||
cursor: 'none',
|
||||
userSelect: 'none',
|
||||
// expand overlay to prevent cursor from going outside it and displaying
|
||||
padding: '10rem',
|
||||
}}
|
||||
>
|
||||
<AnimatePresence>
|
||||
{activeDragData && (
|
||||
<motion.div
|
||||
@ -98,8 +113,8 @@ const ImageDndContext = (props: ImageDndContextProps) => {
|
||||
)}
|
||||
</AnimatePresence>
|
||||
</DragOverlay>
|
||||
</DndContext>
|
||||
</DndContextTypesafe>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ImageDndContext);
|
||||
export default memo(AppDndContext);
|
@ -0,0 +1,6 @@
import { DndContext } from '@dnd-kit/core';
import { DndContextTypesafeProps } from '../types';

export function DndContextTypesafe(props: DndContextTypesafeProps) {
return <DndContext {...props} />;
}

@ -1,6 +1,6 @@
|
||||
import { Box, ChakraProps, Flex, Heading, Image } from '@chakra-ui/react';
|
||||
import { Box, ChakraProps, Flex, Heading, Image, Text } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import { TypesafeDraggableData } from './typesafeDnd';
|
||||
import { TypesafeDraggableData } from '../types';
|
||||
|
||||
type OverlayDragImageProps = {
|
||||
dragData: TypesafeDraggableData | null;
|
||||
@ -30,19 +30,38 @@ const DragPreview = (props: OverlayDragImageProps) => {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (props.dragData.payloadType === 'NODE_FIELD') {
|
||||
const { field, fieldTemplate } = props.dragData.payload;
|
||||
return (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'relative',
|
||||
p: 2,
|
||||
px: 3,
|
||||
opacity: 0.7,
|
||||
bg: 'base.300',
|
||||
borderRadius: 'base',
|
||||
boxShadow: 'dark-lg',
|
||||
whiteSpace: 'nowrap',
|
||||
fontSize: 'sm',
|
||||
}}
|
||||
>
|
||||
<Text>{field.label || fieldTemplate.title}</Text>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
if (props.dragData.payloadType === 'IMAGE_DTO') {
|
||||
const { thumbnail_url, width, height } = props.dragData.payload.imageDTO;
|
||||
return (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'relative',
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
width: 'full',
|
||||
height: 'full',
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
userSelect: 'none',
|
||||
cursor: 'none',
|
||||
}}
|
||||
>
|
||||
<Image
|
||||
@ -62,8 +81,6 @@ const DragPreview = (props: OverlayDragImageProps) => {
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
cursor: 'none',
|
||||
userSelect: 'none',
|
||||
position: 'relative',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
@ -0,0 +1,15 @@
import { useDraggable, useDroppable } from '@dnd-kit/core';
import {
UseDraggableTypesafeArguments,
UseDraggableTypesafeReturnValue,
UseDroppableTypesafeArguments,
UseDroppableTypesafeReturnValue,
} from '../types';

export function useDroppableTypesafe(props: UseDroppableTypesafeArguments) {
return useDroppable(props) as UseDroppableTypesafeReturnValue;
}

export function useDraggableTypesafe(props: UseDraggableTypesafeArguments) {
return useDraggable(props) as UseDraggableTypesafeReturnValue;
}

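A small sketch of what the typed wrappers buy: drop data is narrowed to the app's TypesafeDroppableData / TypesafeDraggableData unions at the call site. The component below is illustrative only, not part of this diff:

import { useRef } from 'react';
import { v4 as uuidv4 } from 'uuid';
import { useDroppableTypesafe } from 'features/dnd/hooks/typesafeHooks';
import { isValidDrop } from 'features/dnd/util/isValidDrop';
import { TypesafeDroppableData } from 'features/dnd/types';

// ExampleDropTarget is a hypothetical consumer of the typed hook.
const ExampleDropTarget = (props: { data: TypesafeDroppableData }) => {
  const dndId = useRef(uuidv4());
  const { isOver, setNodeRef, active } = useDroppableTypesafe({
    id: dndId.current,
    data: props.data,
  });
  // `active?.data.current` is typed as TypesafeDraggableData | undefined,
  // so isValidDrop can check the dragged payload against this drop target.
  const canDrop = isValidDrop(props.data, active);
  return (
    <div ref={setNodeRef} style={{ opacity: isOver && canDrop ? 1 : 0.5 }} />
  );
};

export default ExampleDropTarget;
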
@ -0,0 +1,51 @@
import type { Modifier } from '@dnd-kit/core';
import { getEventCoordinates } from '@dnd-kit/utilities';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useCallback } from 'react';

const selectZoom = createSelector(
[stateSelector, activeTabNameSelector],
({ nodes }, activeTabName) =>
activeTabName === 'nodes' ? nodes.viewport.zoom : 1
);

/**
* Applies scaling to the drag transform (if on node editor tab) and centers it on cursor.
*/
export const useScaledModifer = () => {
const zoom = useAppSelector(selectZoom);
const modifier: Modifier = useCallback(
({ activatorEvent, draggingNodeRect, transform }) => {
if (draggingNodeRect && activatorEvent) {
const activatorCoordinates = getEventCoordinates(activatorEvent);

if (!activatorCoordinates) {
return transform;
}

const offsetX = activatorCoordinates.x - draggingNodeRect.left;
const offsetY = activatorCoordinates.y - draggingNodeRect.top;

const x = transform.x + offsetX - draggingNodeRect.width / 2;
const y = transform.y + offsetY - draggingNodeRect.height / 2;
const scaleX = transform.scaleX * zoom;
const scaleY = transform.scaleY * zoom;

return {
x,
y,
scaleX,
scaleY,
};
}

return transform;
},
[zoom]
);

return modifier;
};

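To make the centering and scaling math above concrete, a worked example with made-up numbers (the values are illustrative, not from this diff):

// pointer at (310, 205); dragged rect: left 300, top 200, width 100, height 60;
// node-editor zoom 0.5; incoming transform { x: 40, y: 30, scaleX: 1, scaleY: 1 }
//   offsetX = 310 - 300 = 10          offsetY = 205 - 200 = 5
//   x = 40 + 10 - 100 / 2 = 0         y = 30 + 5 - 60 / 2 = 5
//   scaleX = 1 * 0.5 = 0.5            scaleY = 1 * 0.5 = 0.5
// The drag preview is re-centered on the cursor and shrunk to the editor's
// zoom before DragOverlay applies it via its `modifiers` prop.
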
@ -3,7 +3,6 @@ import {
|
||||
Active,
|
||||
Collision,
|
||||
DndContextProps,
|
||||
DndContext as OriginalDndContext,
|
||||
Over,
|
||||
Translate,
|
||||
UseDraggableArguments,
|
||||
@ -11,6 +10,10 @@ import {
|
||||
useDraggable as useOriginalDraggable,
|
||||
useDroppable as useOriginalDroppable,
|
||||
} from '@dnd-kit/core';
|
||||
import {
|
||||
InputFieldTemplate,
|
||||
InputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
|
||||
type BaseDropData = {
|
||||
@ -62,6 +65,10 @@ export type RemoveFromBoardDropData = BaseDropData & {
|
||||
actionType: 'REMOVE_FROM_BOARD';
|
||||
};
|
||||
|
||||
export type AddFieldToLinearViewDropData = BaseDropData & {
|
||||
actionType: 'ADD_FIELD_TO_LINEAR';
|
||||
};
|
||||
|
||||
export type TypesafeDroppableData =
|
||||
| CurrentImageDropData
|
||||
| InitialImageDropData
|
||||
@ -71,12 +78,22 @@ export type TypesafeDroppableData =
|
||||
| AddToBatchDropData
|
||||
| NodesMultiImageDropData
|
||||
| AddToBoardDropData
|
||||
| RemoveFromBoardDropData;
|
||||
| RemoveFromBoardDropData
|
||||
| AddFieldToLinearViewDropData;
|
||||
|
||||
type BaseDragData = {
|
||||
id: string;
|
||||
};
|
||||
|
||||
export type NodeFieldDraggableData = BaseDragData & {
|
||||
payloadType: 'NODE_FIELD';
|
||||
payload: {
|
||||
nodeId: string;
|
||||
field: InputFieldValue;
|
||||
fieldTemplate: InputFieldTemplate;
|
||||
};
|
||||
};
|
||||
|
||||
export type ImageDraggableData = BaseDragData & {
|
||||
payloadType: 'IMAGE_DTO';
|
||||
payload: { imageDTO: ImageDTO };
|
||||
@ -87,14 +104,17 @@ export type ImageDTOsDraggableData = BaseDragData & {
|
||||
payload: { imageDTOs: ImageDTO[] };
|
||||
};
|
||||
|
||||
export type TypesafeDraggableData = ImageDraggableData | ImageDTOsDraggableData;
|
||||
export type TypesafeDraggableData =
|
||||
| NodeFieldDraggableData
|
||||
| ImageDraggableData
|
||||
| ImageDTOsDraggableData;
|
||||
|
||||
interface UseDroppableTypesafeArguments
|
||||
export interface UseDroppableTypesafeArguments
|
||||
extends Omit<UseDroppableArguments, 'data'> {
|
||||
data?: TypesafeDroppableData;
|
||||
}
|
||||
|
||||
type UseDroppableTypesafeReturnValue = Omit<
|
||||
export type UseDroppableTypesafeReturnValue = Omit<
|
||||
ReturnType<typeof useOriginalDroppable>,
|
||||
'active' | 'over'
|
||||
> & {
|
||||
@ -102,16 +122,12 @@ type UseDroppableTypesafeReturnValue = Omit<
|
||||
over: TypesafeOver | null;
|
||||
};
|
||||
|
||||
export function useDroppable(props: UseDroppableTypesafeArguments) {
|
||||
return useOriginalDroppable(props) as UseDroppableTypesafeReturnValue;
|
||||
}
|
||||
|
||||
interface UseDraggableTypesafeArguments
|
||||
export interface UseDraggableTypesafeArguments
|
||||
extends Omit<UseDraggableArguments, 'data'> {
|
||||
data?: TypesafeDraggableData;
|
||||
}
|
||||
|
||||
type UseDraggableTypesafeReturnValue = Omit<
|
||||
export type UseDraggableTypesafeReturnValue = Omit<
|
||||
ReturnType<typeof useOriginalDraggable>,
|
||||
'active' | 'over'
|
||||
> & {
|
||||
@ -119,102 +135,14 @@ type UseDraggableTypesafeReturnValue = Omit<
|
||||
over: TypesafeOver | null;
|
||||
};
|
||||
|
||||
export function useDraggable(props: UseDraggableTypesafeArguments) {
|
||||
return useOriginalDraggable(props) as UseDraggableTypesafeReturnValue;
|
||||
}
|
||||
|
||||
interface TypesafeActive extends Omit<Active, 'data'> {
|
||||
export interface TypesafeActive extends Omit<Active, 'data'> {
|
||||
data: React.MutableRefObject<TypesafeDraggableData | undefined>;
|
||||
}
|
||||
|
||||
interface TypesafeOver extends Omit<Over, 'data'> {
|
||||
export interface TypesafeOver extends Omit<Over, 'data'> {
|
||||
data: React.MutableRefObject<TypesafeDroppableData | undefined>;
|
||||
}
|
||||
|
||||
export const isValidDrop = (
|
||||
overData: TypesafeDroppableData | undefined,
|
||||
active: TypesafeActive | null
|
||||
) => {
|
||||
if (!overData || !active?.data.current) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const { actionType } = overData;
|
||||
const { payloadType } = active.data.current;
|
||||
|
||||
if (overData.id === active.data.current.id) {
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (actionType) {
|
||||
case 'SET_CURRENT_IMAGE':
|
||||
return payloadType === 'IMAGE_DTO';
|
||||
case 'SET_INITIAL_IMAGE':
|
||||
return payloadType === 'IMAGE_DTO';
|
||||
case 'SET_CONTROLNET_IMAGE':
|
||||
return payloadType === 'IMAGE_DTO';
|
||||
case 'SET_CANVAS_INITIAL_IMAGE':
|
||||
return payloadType === 'IMAGE_DTO';
|
||||
case 'SET_NODES_IMAGE':
|
||||
return payloadType === 'IMAGE_DTO';
|
||||
case 'SET_MULTI_NODES_IMAGE':
|
||||
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
||||
case 'ADD_TO_BATCH':
|
||||
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
||||
case 'ADD_TO_BOARD': {
|
||||
// If the board is the same, don't allow the drop
|
||||
|
||||
// Check the payload types
|
||||
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
||||
if (!isPayloadValid) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if the image's board is the board we are dragging onto
|
||||
if (payloadType === 'IMAGE_DTO') {
|
||||
const { imageDTO } = active.data.current.payload;
|
||||
const currentBoard = imageDTO.board_id ?? 'none';
|
||||
const destinationBoard = overData.context.boardId;
|
||||
|
||||
return currentBoard !== destinationBoard;
|
||||
}
|
||||
|
||||
if (payloadType === 'IMAGE_DTOS') {
|
||||
// TODO (multi-select)
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
case 'REMOVE_FROM_BOARD': {
|
||||
// If the board is the same, don't allow the drop
|
||||
|
||||
// Check the payload types
|
||||
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
||||
if (!isPayloadValid) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if the image's board is the board we are dragging onto
|
||||
if (payloadType === 'IMAGE_DTO') {
|
||||
const { imageDTO } = active.data.current.payload;
|
||||
const currentBoard = imageDTO.board_id;
|
||||
|
||||
return currentBoard !== 'none';
|
||||
}
|
||||
|
||||
if (payloadType === 'IMAGE_DTOS') {
|
||||
// TODO (multi-select)
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
interface DragEvent {
|
||||
activatorEvent: Event;
|
||||
active: TypesafeActive;
|
||||
@ -240,6 +168,3 @@ export interface DndContextTypesafeProps
|
||||
onDragEnd?(event: DragEndEvent): void;
|
||||
onDragCancel?(event: DragCancelEvent): void;
|
||||
}
|
||||
export function DndContext(props: DndContextTypesafeProps) {
|
||||
return <OriginalDndContext {...props} />;
|
||||
}
|
invokeai/frontend/web/src/features/dnd/util/isValidDrop.ts (new file, 87 lines)
@ -0,0 +1,87 @@
import { TypesafeActive, TypesafeDroppableData } from '../types';

export const isValidDrop = (
  overData: TypesafeDroppableData | undefined,
  active: TypesafeActive | null
) => {
  if (!overData || !active?.data.current) {
    return false;
  }

  const { actionType } = overData;
  const { payloadType } = active.data.current;

  if (overData.id === active.data.current.id) {
    return false;
  }

  switch (actionType) {
    case 'ADD_FIELD_TO_LINEAR':
      return payloadType === 'NODE_FIELD';
    case 'SET_CURRENT_IMAGE':
      return payloadType === 'IMAGE_DTO';
    case 'SET_INITIAL_IMAGE':
      return payloadType === 'IMAGE_DTO';
    case 'SET_CONTROLNET_IMAGE':
      return payloadType === 'IMAGE_DTO';
    case 'SET_CANVAS_INITIAL_IMAGE':
      return payloadType === 'IMAGE_DTO';
    case 'SET_NODES_IMAGE':
      return payloadType === 'IMAGE_DTO';
    case 'SET_MULTI_NODES_IMAGE':
      return payloadType === 'IMAGE_DTO' || payloadType === 'IMAGE_DTOS';
    case 'ADD_TO_BATCH':
      return payloadType === 'IMAGE_DTO' || payloadType === 'IMAGE_DTOS';
    case 'ADD_TO_BOARD': {
      // If the board is the same, don't allow the drop

      // Check the payload types
      const isPayloadValid = payloadType === 'IMAGE_DTO' || payloadType === 'IMAGE_DTOS';
      if (!isPayloadValid) {
        return false;
      }

      // Check if the image's board is the board we are dragging onto
      if (payloadType === 'IMAGE_DTO') {
        const { imageDTO } = active.data.current.payload;
        const currentBoard = imageDTO.board_id ?? 'none';
        const destinationBoard = overData.context.boardId;

        return currentBoard !== destinationBoard;
      }

      if (payloadType === 'IMAGE_DTOS') {
        // TODO (multi-select)
        return true;
      }

      return false;
    }
    case 'REMOVE_FROM_BOARD': {
      // If the board is the same, don't allow the drop

      // Check the payload types
      const isPayloadValid = payloadType === 'IMAGE_DTO' || payloadType === 'IMAGE_DTOS';
      if (!isPayloadValid) {
        return false;
      }

      // Check if the image's board is the board we are dragging onto
      if (payloadType === 'IMAGE_DTO') {
        const { imageDTO } = active.data.current.payload;
        const currentBoard = imageDTO.board_id ?? 'none';

        return currentBoard !== 'none';
      }

      if (payloadType === 'IMAGE_DTOS') {
        // TODO (multi-select)
        return true;
      }

      return false;
    }
    default:
      return false;
  }
};
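For orientation, a minimal sketch of how a guard like isValidDrop is typically consumed from a dnd-kit drag-end callback; the handler below and its wiring are illustrative assumptions, not part of this commit.

// Sketch only (assumed wiring): gate a drop with isValidDrop before
// acting on it, so invalid targets are ignored up front.
import { isValidDrop } from 'features/dnd/util/isValidDrop';
import { TypesafeActive, TypesafeDroppableData } from 'features/dnd/types';

export const handleTypesafeDragEnd = (
  active: TypesafeActive | null,
  overData: TypesafeDroppableData | undefined
) => {
  if (!isValidDrop(overData, active)) {
    // e.g. an image dropped back onto the board it already belongs to
    return;
  }
  // ...dispatch the action described by overData.actionType here
};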
@ -11,7 +11,6 @@ import {
|
||||
} from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
import { AddToBoardDropData } from 'app/components/ImageDnd/typesafeDnd';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
@ -32,6 +31,7 @@ import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { BoardDTO } from 'services/api/types';
|
||||
import AutoAddIcon from '../AutoAddIcon';
|
||||
import BoardContextMenu from '../BoardContextMenu';
|
||||
import { AddToBoardDropData } from 'features/dnd/types';
|
||||
|
||||
interface GalleryBoardProps {
|
||||
board: BoardDTO;
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { As, Badge, Flex } from '@chakra-ui/react';
|
||||
import { TypesafeDroppableData } from 'app/components/ImageDnd/typesafeDnd';
|
||||
import IAIDroppable from 'common/components/IAIDroppable';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { TypesafeDroppableData } from 'features/dnd/types';
|
||||
import { BoardId } from 'features/gallery/store/types';
|
||||
import { ReactNode } from 'react';
|
||||
import BoardContextMenu from '../BoardContextMenu';
|
||||
|
@ -1,15 +1,15 @@
|
||||
import { Box, Flex, Image, Text } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RemoveFromBoardDropData } from 'app/components/ImageDnd/typesafeDnd';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import InvokeAILogoImage from 'assets/images/logo.png';
|
||||
import IAIDroppable from 'common/components/IAIDroppable';
|
||||
import SelectionOverlay from 'common/components/SelectionOverlay';
|
||||
import { RemoveFromBoardDropData } from 'features/dnd/types';
|
||||
import {
|
||||
boardIdSelected,
|
||||
autoAddBoardIdChanged,
|
||||
boardIdSelected,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useBoardName } from 'services/api/hooks/useBoardName';
|
||||
|
@ -1,14 +1,14 @@
|
||||
import { Box, Flex, Image } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
import {
|
||||
TypesafeDraggableData,
|
||||
TypesafeDroppableData,
|
||||
} from 'app/components/ImageDnd/typesafeDnd';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import {
|
||||
TypesafeDraggableData,
|
||||
TypesafeDroppableData,
|
||||
} from 'features/dnd/types';
|
||||
import { useNextPrevImage } from 'features/gallery/hooks/useNextPrevImage';
|
||||
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
|
||||
import { AnimatePresence, motion } from 'framer-motion';
|
||||
|
@ -11,7 +11,6 @@ import {
|
||||
autoAssignBoardOnClickChanged,
|
||||
setGalleryImageMinimumWidth,
|
||||
shouldAutoSwitchChanged,
|
||||
shouldShowDeleteButtonChanged,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { ChangeEvent, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -26,14 +25,12 @@ const selector = createSelector(
|
||||
galleryImageMinimumWidth,
|
||||
shouldAutoSwitch,
|
||||
autoAssignBoardOnClick,
|
||||
shouldShowDeleteButton,
|
||||
} = state.gallery;
|
||||
|
||||
return {
|
||||
galleryImageMinimumWidth,
|
||||
shouldAutoSwitch,
|
||||
autoAssignBoardOnClick,
|
||||
shouldShowDeleteButton,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
@ -43,12 +40,8 @@ const GallerySettingsPopover = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const {
|
||||
galleryImageMinimumWidth,
|
||||
shouldAutoSwitch,
|
||||
autoAssignBoardOnClick,
|
||||
shouldShowDeleteButton,
|
||||
} = useAppSelector(selector);
|
||||
const { galleryImageMinimumWidth, shouldAutoSwitch, autoAssignBoardOnClick } =
|
||||
useAppSelector(selector);
|
||||
|
||||
const handleChangeGalleryImageMinimumWidth = useCallback(
|
||||
(v: number) => {
|
||||
@ -68,13 +61,6 @@ const GallerySettingsPopover = () => {
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleChangeShowDeleteButton = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
dispatch(shouldShowDeleteButtonChanged(e.target.checked));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIPopover
|
||||
triggerComponent={
|
||||
@ -90,7 +76,7 @@ const GallerySettingsPopover = () => {
|
||||
<IAISlider
|
||||
value={galleryImageMinimumWidth}
|
||||
onChange={handleChangeGalleryImageMinimumWidth}
|
||||
min={32}
|
||||
min={45}
|
||||
max={256}
|
||||
hideTooltip={true}
|
||||
label={t('gallery.galleryImageSize')}
|
||||
@ -102,11 +88,6 @@ const GallerySettingsPopover = () => {
|
||||
isChecked={shouldAutoSwitch}
|
||||
onChange={handleChangeAutoSwitch}
|
||||
/>
|
||||
<IAISwitch
|
||||
label="Show Delete Button"
|
||||
isChecked={shouldShowDeleteButton}
|
||||
onChange={handleChangeShowDeleteButton}
|
||||
/>
|
||||
<IAISimpleCheckbox
|
||||
label={t('gallery.autoAssignBoardOnClick')}
|
||||
isChecked={autoAssignBoardOnClick}
|
||||
|
@ -5,13 +5,21 @@ import {
|
||||
isModalOpenChanged,
|
||||
} from 'features/changeBoardModal/store/slice';
|
||||
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
||||
import { useCallback } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { FaFolder, FaTrash } from 'react-icons/fa';
|
||||
import { MdStar, MdStarBorder } from 'react-icons/md';
|
||||
import {
|
||||
useStarImagesMutation,
|
||||
useUnstarImagesMutation,
|
||||
} from 'services/api/endpoints/images';
|
||||
|
||||
const MultipleSelectionMenuItems = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const selection = useAppSelector((state) => state.gallery.selection);
|
||||
|
||||
const [starImages] = useStarImagesMutation();
|
||||
const [unstarImages] = useUnstarImagesMutation();
|
||||
|
||||
const handleChangeBoard = useCallback(() => {
|
||||
dispatch(imagesToChangeSelected(selection));
|
||||
dispatch(isModalOpenChanged(true));
|
||||
@ -21,8 +29,37 @@ const MultipleSelectionMenuItems = () => {
|
||||
dispatch(imagesToDeleteSelected(selection));
|
||||
}, [dispatch, selection]);
|
||||
|
||||
const handleStarSelection = useCallback(() => {
|
||||
starImages({ imageDTOs: selection });
|
||||
}, [starImages, selection]);
|
||||
|
||||
const handleUnstarSelection = useCallback(() => {
|
||||
unstarImages({ imageDTOs: selection });
|
||||
}, [unstarImages, selection]);
|
||||
|
||||
const areAllStarred = useMemo(() => {
|
||||
return selection.every((img) => img.starred);
|
||||
}, [selection]);
|
||||
|
||||
const areAllUnstarred = useMemo(() => {
|
||||
return selection.every((img) => !img.starred);
|
||||
}, [selection]);
|
||||
|
||||
return (
|
||||
<>
|
||||
{areAllStarred && (
|
||||
<MenuItem
|
||||
icon={<MdStarBorder />}
|
||||
onClickCapture={handleUnstarSelection}
|
||||
>
|
||||
Unstar All
|
||||
</MenuItem>
|
||||
)}
|
||||
{(areAllUnstarred || (!areAllStarred && !areAllUnstarred)) && (
|
||||
<MenuItem icon={<MdStar />} onClickCapture={handleStarSelection}>
|
||||
Star All
|
||||
</MenuItem>
|
||||
)}
|
||||
<MenuItem icon={<FaFolder />} onClickCapture={handleChangeBoard}>
|
||||
Change Board
|
||||
</MenuItem>
|
||||
|
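The star/unstar menu items above call useStarImagesMutation and useUnstarImagesMutation, whose definitions are not part of this diff; a hedged sketch of how such RTK Query endpoints are commonly declared follows. The endpoint names and the { imageDTOs } argument match this diff; the route, request body, and cache tags are assumptions.

// Sketch only: assumed shape of the mutations behind the hooks used above.
import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react';
import { ImageDTO } from 'services/api/types';

export const imagesApi = createApi({
  reducerPath: 'imagesApi',
  baseQuery: fetchBaseQuery({ baseUrl: '/api/v1/' }), // assumed base URL
  tagTypes: ['Image'],
  endpoints: (build) => ({
    starImages: build.mutation<void, { imageDTOs: ImageDTO[] }>({
      query: ({ imageDTOs }) => ({
        url: 'images/star', // assumed route
        method: 'POST',
        body: { image_names: imageDTOs.map((i) => i.image_name) }, // assumed body
      }),
      invalidatesTags: ['Image'],
    }),
    unstarImages: build.mutation<void, { imageDTOs: ImageDTO[] }>({
      query: ({ imageDTOs }) => ({
        url: 'images/unstar', // assumed route
        method: 'POST',
        body: { image_names: imageDTOs.map((i) => i.image_name) }, // assumed body
      }),
      invalidatesTags: ['Image'],
    }),
  }),
});

export const { useStarImagesMutation, useUnstarImagesMutation } = imagesApi;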
@ -29,10 +29,15 @@ import {
|
||||
FaShare,
|
||||
FaTrash,
|
||||
} from 'react-icons/fa';
|
||||
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
|
||||
import {
|
||||
useGetImageMetadataQuery,
|
||||
useStarImagesMutation,
|
||||
useUnstarImagesMutation,
|
||||
} from 'services/api/endpoints/images';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
|
||||
import { MdStar, MdStarBorder } from 'react-icons/md';
|
||||
|
||||
type SingleSelectionMenuItemsProps = {
|
||||
imageDTO: ImageDTO;
|
||||
@ -59,6 +64,9 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
: debouncedMetadataQueryArg ?? skipToken
|
||||
);
|
||||
|
||||
const [starImages] = useStarImagesMutation();
|
||||
const [unstarImages] = useUnstarImagesMutation();
|
||||
|
||||
const { isClipboardAPIAvailable, copyImageToClipboard } =
|
||||
useCopyImageToClipboard();
|
||||
|
||||
@ -127,6 +135,14 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
copyImageToClipboard(imageDTO.image_url);
|
||||
}, [copyImageToClipboard, imageDTO.image_url]);
|
||||
|
||||
const handleStarImage = useCallback(() => {
|
||||
if (imageDTO) starImages({ imageDTOs: [imageDTO] });
|
||||
}, [starImages, imageDTO]);
|
||||
|
||||
const handleUnstarImage = useCallback(() => {
|
||||
if (imageDTO) unstarImages({ imageDTOs: [imageDTO] });
|
||||
}, [unstarImages, imageDTO]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<MenuItem
|
||||
@ -196,6 +212,15 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
<MenuItem icon={<FaFolder />} onClickCapture={handleChangeBoard}>
|
||||
Change Board
|
||||
</MenuItem>
|
||||
{imageDTO.starred ? (
|
||||
<MenuItem icon={<MdStar />} onClickCapture={handleUnstarImage}>
|
||||
Unstar Image
|
||||
</MenuItem>
|
||||
) : (
|
||||
<MenuItem icon={<MdStarBorder />} onClickCapture={handleStarImage}>
|
||||
Star Image
|
||||
</MenuItem>
|
||||
)}
|
||||
<MenuItem
|
||||
sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
|
||||
icon={<FaTrash />}
|
||||
|
@ -52,11 +52,13 @@ const ImageGalleryContent = () => {
|
||||
|
||||
return (
|
||||
<VStack
|
||||
layerStyle="first"
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
h: 'full',
|
||||
w: 'full',
|
||||
borderRadius: 'base',
|
||||
p: 2,
|
||||
}}
|
||||
>
|
||||
<Box sx={{ w: 'full' }}>
|
||||
|
@ -1,17 +1,23 @@
|
||||
import { Box, Flex } from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import IAIFillSkeleton from 'common/components/IAIFillSkeleton';
|
||||
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
||||
import {
|
||||
ImageDTOsDraggableData,
|
||||
ImageDraggableData,
|
||||
TypesafeDraggableData,
|
||||
} from 'app/components/ImageDnd/typesafeDnd';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import IAIFillSkeleton from 'common/components/IAIFillSkeleton';
|
||||
} from 'features/dnd/types';
|
||||
import { useMultiselect } from 'features/gallery/hooks/useMultiselect.ts';
|
||||
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
||||
import { MouseEvent, memo, useCallback, useMemo } from 'react';
|
||||
import { MouseEvent, memo, useCallback, useMemo, useState } from 'react';
|
||||
import { FaTrash } from 'react-icons/fa';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { MdStar, MdStarBorder } from 'react-icons/md';
|
||||
import {
|
||||
useGetImageDTOQuery,
|
||||
useStarImagesMutation,
|
||||
useUnstarImagesMutation,
|
||||
} from 'services/api/endpoints/images';
|
||||
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
|
||||
|
||||
interface HoverableImageProps {
|
||||
imageName: string;
|
||||
@ -21,9 +27,7 @@ const GalleryImage = (props: HoverableImageProps) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { imageName } = props;
|
||||
const { currentData: imageDTO } = useGetImageDTOQuery(imageName);
|
||||
const shouldShowDeleteButton = useAppSelector(
|
||||
(state) => state.gallery.shouldShowDeleteButton
|
||||
);
|
||||
const shift = useAppSelector((state) => state.hotkeys.shift);
|
||||
|
||||
const { handleClick, isSelected, selection, selectionCount } =
|
||||
useMultiselect(imageDTO);
|
||||
@ -59,6 +63,35 @@ const GalleryImage = (props: HoverableImageProps) => {
|
||||
}
|
||||
}, [imageDTO, selection, selectionCount]);
|
||||
|
||||
const [starImages] = useStarImagesMutation();
|
||||
const [unstarImages] = useUnstarImagesMutation();
|
||||
|
||||
const toggleStarredState = useCallback(() => {
|
||||
if (imageDTO) {
|
||||
if (imageDTO.starred) {
|
||||
unstarImages({ imageDTOs: [imageDTO] });
|
||||
}
|
||||
if (!imageDTO.starred) {
|
||||
starImages({ imageDTOs: [imageDTO] });
|
||||
}
|
||||
}
|
||||
}, [starImages, unstarImages, imageDTO]);
|
||||
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
|
||||
const handleMouseOver = useCallback(() => {
|
||||
setIsHovered(true);
|
||||
}, []);
|
||||
|
||||
const handleMouseOut = useCallback(() => {
|
||||
setIsHovered(false);
|
||||
}, []);
|
||||
|
||||
const starIcon = useMemo(() => {
|
||||
if (imageDTO?.starred) return <MdStar size="20" />;
|
||||
if (!imageDTO?.starred && isHovered) return <MdStarBorder size="20" />;
|
||||
}, [imageDTO?.starred, isHovered]);
|
||||
|
||||
if (!imageDTO) {
|
||||
return <IAIFillSkeleton />;
|
||||
}
|
||||
@ -80,16 +113,34 @@ const GalleryImage = (props: HoverableImageProps) => {
|
||||
draggableData={draggableData}
|
||||
isSelected={isSelected}
|
||||
minSize={0}
|
||||
onClickReset={handleDelete}
|
||||
imageSx={{ w: 'full', h: 'full' }}
|
||||
isDropDisabled={true}
|
||||
isUploadDisabled={true}
|
||||
thumbnail={true}
|
||||
withHoverOverlay
|
||||
resetIcon={<FaTrash />}
|
||||
resetTooltip="Delete image"
|
||||
withResetIcon={shouldShowDeleteButton} // removed because it's too easy to accidentally delete images
|
||||
/>
|
||||
onMouseOver={handleMouseOver}
|
||||
onMouseOut={handleMouseOut}
|
||||
>
|
||||
<>
|
||||
<IAIDndImageIcon
|
||||
onClick={toggleStarredState}
|
||||
icon={starIcon}
|
||||
tooltip={imageDTO.starred ? 'Unstar' : 'Star'}
|
||||
/>
|
||||
|
||||
{isHovered && shift && (
|
||||
<IAIDndImageIcon
|
||||
onClick={handleDelete}
|
||||
icon={<FaTrash />}
|
||||
tooltip="Delete"
|
||||
styleOverrides={{
|
||||
bottom: 2,
|
||||
top: 'auto',
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
</IAIDndImage>
|
||||
</Flex>
|
||||
</Box>
|
||||
);
|
||||
|
@ -26,7 +26,7 @@ const overlayScrollbarsConfig: UseOverlayScrollbarsParams = {
|
||||
options: {
|
||||
scrollbars: {
|
||||
visibility: 'auto',
|
||||
autoHide: 'leave',
|
||||
autoHide: 'scroll',
|
||||
autoHideDelay: 1300,
|
||||
theme: 'os-theme-dark',
|
||||
},
|
||||
|
@ -1,26 +1,40 @@
|
||||
import { Box, Flex, IconButton, Tooltip } from '@chakra-ui/react';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import { useMemo } from 'react';
|
||||
import { FaCopy } from 'react-icons/fa';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { FaCopy, FaSave } from 'react-icons/fa';
|
||||
|
||||
type Props = {
|
||||
copyTooltip: string;
|
||||
label: string;
|
||||
jsonObject: object;
|
||||
fileName?: string;
|
||||
};
|
||||
|
||||
const ImageMetadataJSON = (props: Props) => {
|
||||
const { copyTooltip, jsonObject } = props;
|
||||
const { label, jsonObject, fileName } = props;
|
||||
const jsonString = useMemo(
|
||||
() => JSON.stringify(jsonObject, null, 2),
|
||||
[jsonObject]
|
||||
);
|
||||
|
||||
const handleCopy = useCallback(() => {
|
||||
navigator.clipboard.writeText(jsonString);
|
||||
}, [jsonString]);
|
||||
|
||||
const handleSave = useCallback(() => {
|
||||
const blob = new Blob([jsonString]);
|
||||
const a = document.createElement('a');
|
||||
a.href = URL.createObjectURL(blob);
|
||||
a.download = `${fileName || label}.json`;
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
a.remove();
|
||||
}, [jsonString, label, fileName]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
layerStyle="second"
|
||||
sx={{
|
||||
borderRadius: 'base',
|
||||
bg: 'whiteAlpha.500',
|
||||
_dark: { bg: 'blackAlpha.500' },
|
||||
flexGrow: 1,
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
@ -36,6 +50,7 @@ const ImageMetadataJSON = (props: Props) => {
|
||||
bottom: 0,
|
||||
overflow: 'auto',
|
||||
p: 4,
|
||||
fontSize: 'sm',
|
||||
}}
|
||||
>
|
||||
<OverlayScrollbarsComponent
|
||||
@ -44,7 +59,7 @@ const ImageMetadataJSON = (props: Props) => {
|
||||
options={{
|
||||
scrollbars: {
|
||||
visibility: 'auto',
|
||||
autoHide: 'move',
|
||||
autoHide: 'scroll',
|
||||
autoHideDelay: 1300,
|
||||
theme: 'os-theme-dark',
|
||||
},
|
||||
@ -54,12 +69,22 @@ const ImageMetadataJSON = (props: Props) => {
|
||||
</OverlayScrollbarsComponent>
|
||||
</Box>
|
||||
<Flex sx={{ position: 'absolute', top: 0, insetInlineEnd: 0, p: 2 }}>
|
||||
<Tooltip label={copyTooltip}>
|
||||
<Tooltip label={`Save ${label} JSON`}>
|
||||
<IconButton
|
||||
aria-label={copyTooltip}
|
||||
aria-label={`Save ${label} JSON`}
|
||||
icon={<FaSave />}
|
||||
variant="ghost"
|
||||
opacity={0.7}
|
||||
onClick={handleSave}
|
||||
/>
|
||||
</Tooltip>
|
||||
<Tooltip label={`Copy ${label} JSON`}>
|
||||
<IconButton
|
||||
aria-label={`Copy ${label} JSON`}
|
||||
icon={<FaCopy />}
|
||||
variant="ghost"
|
||||
onClick={() => navigator.clipboard.writeText(jsonString)}
|
||||
opacity={0.7}
|
||||
onClick={handleCopy}
|
||||
/>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
|
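One note on the save handler above: URL.createObjectURL allocates a blob URL that handleSave never releases. A small variant sketch with the same behaviour that revokes the URL once the click has been dispatched.

// Variant sketch of the download flow above, releasing the object URL.
const saveJson = (jsonString: string, fileName: string) => {
  const blob = new Blob([jsonString], { type: 'application/json' });
  const url = URL.createObjectURL(blob);
  const a = document.createElement('a');
  a.href = url;
  a.download = `${fileName}.json`;
  document.body.appendChild(a);
  a.click();
  a.remove();
  URL.revokeObjectURL(url);
};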
@ -10,7 +10,8 @@ import {
|
||||
Text,
|
||||
} from '@chakra-ui/react';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { memo } from 'react';
|
||||
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
@ -41,48 +42,15 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
|
||||
const metadata = currentData?.metadata;
|
||||
const graph = currentData?.graph;
|
||||
|
||||
const tabData = useMemo(() => {
|
||||
const _tabData: { label: string; data: object; copyTooltip: string }[] = [];
|
||||
|
||||
if (metadata) {
|
||||
_tabData.push({
|
||||
label: 'Core Metadata',
|
||||
data: metadata,
|
||||
copyTooltip: 'Copy Core Metadata JSON',
|
||||
});
|
||||
}
|
||||
|
||||
if (image) {
|
||||
_tabData.push({
|
||||
label: 'Image Details',
|
||||
data: image,
|
||||
copyTooltip: 'Copy Image Details JSON',
|
||||
});
|
||||
}
|
||||
|
||||
if (graph) {
|
||||
_tabData.push({
|
||||
label: 'Graph',
|
||||
data: graph,
|
||||
copyTooltip: 'Copy Graph JSON',
|
||||
});
|
||||
}
|
||||
return _tabData;
|
||||
}, [metadata, graph, image]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
layerStyle="first"
|
||||
sx={{
|
||||
padding: 4,
|
||||
gap: 1,
|
||||
flexDirection: 'column',
|
||||
width: 'full',
|
||||
height: 'full',
|
||||
backdropFilter: 'blur(20px)',
|
||||
bg: 'baseAlpha.200',
|
||||
_dark: {
|
||||
bg: 'blackAlpha.600',
|
||||
},
|
||||
borderRadius: 'base',
|
||||
position: 'absolute',
|
||||
overflow: 'hidden',
|
||||
@ -103,32 +71,33 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
|
||||
sx={{ display: 'flex', flexDir: 'column', w: 'full', h: 'full' }}
|
||||
>
|
||||
<TabList>
|
||||
{tabData.map((tab) => (
|
||||
<Tab
|
||||
key={tab.label}
|
||||
sx={{
|
||||
borderTopRadius: 'base',
|
||||
}}
|
||||
>
|
||||
<Text sx={{ color: 'base.700', _dark: { color: 'base.300' } }}>
|
||||
{tab.label}
|
||||
</Text>
|
||||
</Tab>
|
||||
))}
|
||||
<Tab>Core Metadata</Tab>
|
||||
<Tab>Image Details</Tab>
|
||||
<Tab>Graph</Tab>
|
||||
</TabList>
|
||||
|
||||
<TabPanels sx={{ w: 'full', h: 'full' }}>
|
||||
{tabData.map((tab) => (
|
||||
<TabPanel
|
||||
key={tab.label}
|
||||
sx={{ w: 'full', h: 'full', p: 0, pt: 4 }}
|
||||
>
|
||||
<ImageMetadataJSON
|
||||
jsonObject={tab.data}
|
||||
copyTooltip={tab.copyTooltip}
|
||||
/>
|
||||
</TabPanel>
|
||||
))}
|
||||
<TabPanels>
|
||||
<TabPanel>
|
||||
{metadata ? (
|
||||
<ImageMetadataJSON jsonObject={metadata} label="Core Metadata" />
|
||||
) : (
|
||||
<IAINoContentFallback label="No core metadata found" />
|
||||
)}
|
||||
</TabPanel>
|
||||
<TabPanel>
|
||||
{image ? (
|
||||
<ImageMetadataJSON jsonObject={image} label="Image Details" />
|
||||
) : (
|
||||
<IAINoContentFallback label="No image details found" />
|
||||
)}
|
||||
</TabPanel>
|
||||
<TabPanel>
|
||||
{graph ? (
|
||||
<ImageMetadataJSON jsonObject={graph} label="Graph" />
|
||||
) : (
|
||||
<IAINoContentFallback label="No graph found" />
|
||||
)}
|
||||
</TabPanel>
|
||||
</TabPanels>
|
||||
</Tabs>
|
||||
</Flex>
|
||||
|
@ -13,7 +13,6 @@ export const initialGalleryState: GalleryState = {
|
||||
galleryImageMinimumWidth: 96,
|
||||
selectedBoardId: 'none',
|
||||
galleryView: 'images',
|
||||
shouldShowDeleteButton: false,
|
||||
boardSearchText: '',
|
||||
};
|
||||
|
||||
@ -50,9 +49,6 @@ export const gallerySlice = createSlice({
|
||||
galleryViewChanged: (state, action: PayloadAction<GalleryView>) => {
|
||||
state.galleryView = action.payload;
|
||||
},
|
||||
shouldShowDeleteButtonChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldShowDeleteButton = action.payload;
|
||||
},
|
||||
boardSearchTextChanged: (state, action: PayloadAction<string>) => {
|
||||
state.boardSearchText = action.payload;
|
||||
},
|
||||
@ -93,7 +89,6 @@ export const {
|
||||
autoAddBoardIdChanged,
|
||||
galleryViewChanged,
|
||||
selectionChanged,
|
||||
shouldShowDeleteButtonChanged,
|
||||
boardSearchTextChanged,
|
||||
} = gallerySlice.actions;
|
||||
|
||||
|
@ -21,6 +21,5 @@ export type GalleryState = {
|
||||
galleryImageMinimumWidth: number;
|
||||
selectedBoardId: BoardId;
|
||||
galleryView: GalleryView;
|
||||
shouldShowDeleteButton: boolean;
|
||||
boardSearchText: string;
|
||||
};
|
||||
|
@ -9,30 +9,40 @@ import { map } from 'lodash-es';
|
||||
import { forwardRef, useCallback } from 'react';
|
||||
import 'reactflow/dist/style.css';
|
||||
import { AnyInvocationType } from 'services/events/types';
|
||||
import { useBuildInvocation } from '../hooks/useBuildInvocation';
|
||||
import { useBuildNodeData } from '../hooks/useBuildNodeData';
|
||||
import { nodeAdded } from '../store/nodesSlice';
|
||||
|
||||
type NodeTemplate = {
|
||||
label: string;
|
||||
value: string;
|
||||
description: string;
|
||||
tags: string[];
|
||||
};
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
({ nodes }) => {
|
||||
const data: NodeTemplate[] = map(nodes.invocationTemplates, (template) => {
|
||||
const data: NodeTemplate[] = map(nodes.nodeTemplates, (template) => {
|
||||
return {
|
||||
label: template.title,
|
||||
value: template.type,
|
||||
description: template.description,
|
||||
tags: template.tags,
|
||||
};
|
||||
});
|
||||
|
||||
data.push({
|
||||
label: 'Progress Image',
|
||||
value: 'progress_image',
|
||||
description: 'Displays the progress image in the Node Editor',
|
||||
value: 'current_image',
|
||||
description: 'Displays the current image in the Node Editor',
|
||||
tags: ['progress'],
|
||||
});
|
||||
|
||||
data.push({
|
||||
label: 'Notes',
|
||||
value: 'notes',
|
||||
description: 'Add notes about your workflow',
|
||||
tags: ['notes'],
|
||||
});
|
||||
|
||||
return { data };
|
||||
@ -44,7 +54,7 @@ const AddNodeMenu = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { data } = useAppSelector(selector);
|
||||
|
||||
const buildInvocation = useBuildInvocation();
|
||||
const buildInvocation = useBuildNodeData();
|
||||
|
||||
const toaster = useAppToaster();
|
||||
|
||||
@ -89,11 +99,12 @@ const AddNodeMenu = () => {
|
||||
filter={(value, item: NodeTemplate) =>
|
||||
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||
item.value.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||
item.description.toLowerCase().includes(value.toLowerCase().trim())
|
||||
item.description.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||
item.tags.includes(value.toLowerCase().trim())
|
||||
}
|
||||
onChange={handleChange}
|
||||
sx={{
|
||||
width: '18rem',
|
||||
width: '24rem',
|
||||
}}
|
||||
/>
|
||||
</Flex>
|
||||
|
@ -0,0 +1,61 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { ConnectionLineComponentProps, getBezierPath } from 'reactflow';
|
||||
import { FIELDS, colorTokenToCssVar } from '../types/constants';
|
||||
|
||||
const selector = createSelector(stateSelector, ({ nodes }) => {
|
||||
const { shouldAnimateEdges, currentConnectionFieldType, shouldColorEdges } =
|
||||
nodes;
|
||||
|
||||
const stroke =
|
||||
currentConnectionFieldType && shouldColorEdges
|
||||
? colorTokenToCssVar(FIELDS[currentConnectionFieldType].color)
|
||||
: colorTokenToCssVar('base.500');
|
||||
|
||||
let className = 'react-flow__custom_connection-path';
|
||||
|
||||
if (shouldAnimateEdges) {
|
||||
className = className.concat(' animated');
|
||||
}
|
||||
|
||||
return {
|
||||
stroke,
|
||||
className,
|
||||
};
|
||||
});
|
||||
|
||||
export const CustomConnectionLine = ({
|
||||
fromX,
|
||||
fromY,
|
||||
fromPosition,
|
||||
toX,
|
||||
toY,
|
||||
toPosition,
|
||||
}: ConnectionLineComponentProps) => {
|
||||
const { stroke, className } = useAppSelector(selector);
|
||||
|
||||
const pathParams = {
|
||||
sourceX: fromX,
|
||||
sourceY: fromY,
|
||||
sourcePosition: fromPosition,
|
||||
targetX: toX,
|
||||
targetY: toY,
|
||||
targetPosition: toPosition,
|
||||
};
|
||||
|
||||
const [dAttr] = getBezierPath(pathParams);
|
||||
|
||||
return (
|
||||
<g>
|
||||
<path
|
||||
fill="none"
|
||||
stroke={stroke}
|
||||
strokeWidth={2}
|
||||
className={className}
|
||||
d={dAttr}
|
||||
style={{ opacity: 0.8 }}
|
||||
/>
|
||||
</g>
|
||||
);
|
||||
};
|
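CustomConnectionLine relies on colorTokenToCssVar from the constants module, which is not shown in this diff; a plausible sketch based on the var(--invokeai-colors-…) naming visible elsewhere in the codebase. The exact implementation is an assumption.

// Assumed helper (not part of this diff): map a Chakra color token such as
// 'base.500' to the CSS custom property generated for it, e.g.
// 'var(--invokeai-colors-base-500)'.
export const colorTokenToCssVar = (colorToken: string): string =>
  `var(--invokeai-colors-${colorToken.split('.').join('-')})`;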
@ -0,0 +1,183 @@
|
||||
import { Badge, Flex } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
|
||||
import { useMemo } from 'react';
|
||||
import {
|
||||
BaseEdge,
|
||||
EdgeLabelRenderer,
|
||||
EdgeProps,
|
||||
getBezierPath,
|
||||
} from 'reactflow';
|
||||
import { FIELDS, colorTokenToCssVar } from '../types/constants';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
|
||||
const makeEdgeSelector = (
|
||||
source: string,
|
||||
sourceHandleId: string | null | undefined,
|
||||
target: string,
|
||||
targetHandleId: string | null | undefined,
|
||||
selected?: boolean
|
||||
) =>
|
||||
createSelector(stateSelector, ({ nodes }) => {
|
||||
const sourceNode = nodes.nodes.find((node) => node.id === source);
|
||||
const targetNode = nodes.nodes.find((node) => node.id === target);
|
||||
|
||||
const isInvocationToInvocationEdge =
|
||||
isInvocationNode(sourceNode) && isInvocationNode(targetNode);
|
||||
|
||||
const isSelected = sourceNode?.selected || targetNode?.selected || selected;
|
||||
const sourceType = isInvocationToInvocationEdge
|
||||
? sourceNode?.data?.outputs[sourceHandleId || '']?.type
|
||||
: undefined;
|
||||
|
||||
const stroke =
|
||||
sourceType && nodes.shouldColorEdges
|
||||
? colorTokenToCssVar(FIELDS[sourceType].color)
|
||||
: colorTokenToCssVar('base.500');
|
||||
|
||||
return {
|
||||
isSelected,
|
||||
shouldAnimate: nodes.shouldAnimateEdges && isSelected,
|
||||
stroke,
|
||||
};
|
||||
});
|
||||
|
||||
const CollapsedEdge = ({
|
||||
sourceX,
|
||||
sourceY,
|
||||
targetX,
|
||||
targetY,
|
||||
sourcePosition,
|
||||
targetPosition,
|
||||
markerEnd,
|
||||
data,
|
||||
selected,
|
||||
source,
|
||||
target,
|
||||
sourceHandleId,
|
||||
targetHandleId,
|
||||
}: EdgeProps<{ count: number }>) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
makeEdgeSelector(
|
||||
source,
|
||||
sourceHandleId,
|
||||
target,
|
||||
targetHandleId,
|
||||
selected
|
||||
),
|
||||
[selected, source, sourceHandleId, target, targetHandleId]
|
||||
);
|
||||
|
||||
const { isSelected, shouldAnimate } = useAppSelector(selector);
|
||||
|
||||
const [edgePath, labelX, labelY] = getBezierPath({
|
||||
sourceX,
|
||||
sourceY,
|
||||
sourcePosition,
|
||||
targetX,
|
||||
targetY,
|
||||
targetPosition,
|
||||
});
|
||||
|
||||
const { base500 } = useChakraThemeTokens();
|
||||
|
||||
return (
|
||||
<>
|
||||
<BaseEdge
|
||||
path={edgePath}
|
||||
markerEnd={markerEnd}
|
||||
style={{
|
||||
strokeWidth: isSelected ? 3 : 2,
|
||||
stroke: base500,
|
||||
opacity: isSelected ? 0.8 : 0.5,
|
||||
animation: shouldAnimate
|
||||
? 'dashdraw 0.5s linear infinite'
|
||||
: undefined,
|
||||
strokeDasharray: shouldAnimate ? 5 : 'none',
|
||||
}}
|
||||
/>
|
||||
{data?.count && data.count > 1 && (
|
||||
<EdgeLabelRenderer>
|
||||
<Flex
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
transform: `translate(-50%, -50%) translate(${labelX}px,${labelY}px)`,
|
||||
}}
|
||||
className="nodrag nopan"
|
||||
>
|
||||
<Badge
|
||||
variant="solid"
|
||||
sx={{
|
||||
bg: 'base.500',
|
||||
opacity: isSelected ? 0.8 : 0.5,
|
||||
boxShadow: 'base',
|
||||
}}
|
||||
>
|
||||
{data.count}
|
||||
</Badge>
|
||||
</Flex>
|
||||
</EdgeLabelRenderer>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
const DefaultEdge = ({
|
||||
sourceX,
|
||||
sourceY,
|
||||
targetX,
|
||||
targetY,
|
||||
sourcePosition,
|
||||
targetPosition,
|
||||
markerEnd,
|
||||
selected,
|
||||
source,
|
||||
target,
|
||||
sourceHandleId,
|
||||
targetHandleId,
|
||||
}: EdgeProps) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
makeEdgeSelector(
|
||||
source,
|
||||
sourceHandleId,
|
||||
target,
|
||||
targetHandleId,
|
||||
selected
|
||||
),
|
||||
[source, sourceHandleId, target, targetHandleId, selected]
|
||||
);
|
||||
|
||||
const { isSelected, shouldAnimate, stroke } = useAppSelector(selector);
|
||||
|
||||
const [edgePath] = getBezierPath({
|
||||
sourceX,
|
||||
sourceY,
|
||||
sourcePosition,
|
||||
targetX,
|
||||
targetY,
|
||||
targetPosition,
|
||||
});
|
||||
|
||||
return (
|
||||
<BaseEdge
|
||||
path={edgePath}
|
||||
markerEnd={markerEnd}
|
||||
style={{
|
||||
strokeWidth: isSelected ? 3 : 2,
|
||||
stroke,
|
||||
opacity: isSelected ? 0.8 : 0.5,
|
||||
animation: shouldAnimate ? 'dashdraw 0.5s linear infinite' : undefined,
|
||||
strokeDasharray: shouldAnimate ? 5 : 'none',
|
||||
}}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export const edgeTypes = {
|
||||
collapsed: CollapsedEdge,
|
||||
default: DefaultEdge,
|
||||
};
|
@ -0,0 +1,9 @@
|
||||
import CurrentImageNode from './nodes/CurrentImageNode';
|
||||
import InvocationNodeWrapper from './nodes/InvocationNodeWrapper';
|
||||
import NotesNode from './nodes/NotesNode';
|
||||
|
||||
export const nodeTypes = {
|
||||
invocation: InvocationNodeWrapper,
|
||||
current_image: CurrentImageNode,
|
||||
notes: NotesNode,
|
||||
};
|
@ -1,64 +0,0 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { CSSProperties, memo } from 'react';
|
||||
import { Handle, Position, Connection, HandleType } from 'reactflow';
|
||||
import { FIELDS, HANDLE_TOOLTIP_OPEN_DELAY } from '../types/constants';
|
||||
// import { useConnectionEventStyles } from '../hooks/useConnectionEventStyles';
|
||||
import { InputFieldTemplate, OutputFieldTemplate } from '../types/types';
|
||||
|
||||
const handleBaseStyles: CSSProperties = {
|
||||
position: 'absolute',
|
||||
width: '1rem',
|
||||
height: '1rem',
|
||||
borderWidth: 0,
|
||||
};
|
||||
|
||||
const inputHandleStyles: CSSProperties = {
|
||||
left: '-1rem',
|
||||
};
|
||||
|
||||
const outputHandleStyles: CSSProperties = {
|
||||
right: '-0.5rem',
|
||||
};
|
||||
|
||||
// const requiredConnectionStyles: CSSProperties = {
|
||||
// boxShadow: '0 0 0.5rem 0.5rem var(--invokeai-colors-error-400)',
|
||||
// };
|
||||
|
||||
type FieldHandleProps = {
|
||||
nodeId: string;
|
||||
field: InputFieldTemplate | OutputFieldTemplate;
|
||||
isValidConnection: (connection: Connection) => boolean;
|
||||
handleType: HandleType;
|
||||
styles?: CSSProperties;
|
||||
};
|
||||
|
||||
const FieldHandle = (props: FieldHandleProps) => {
|
||||
const { field, isValidConnection, handleType, styles } = props;
|
||||
const { name, type } = field;
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
label={type}
|
||||
placement={handleType === 'target' ? 'start' : 'end'}
|
||||
hasArrow
|
||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||
>
|
||||
<Handle
|
||||
type={handleType}
|
||||
id={name}
|
||||
isValidConnection={isValidConnection}
|
||||
position={handleType === 'target' ? Position.Left : Position.Right}
|
||||
style={{
|
||||
backgroundColor: FIELDS[type].colorCssVar,
|
||||
...styles,
|
||||
...handleBaseStyles,
|
||||
...(handleType === 'target' ? inputHandleStyles : outputHandleStyles),
|
||||
// ...(inputRequirement === 'always' ? requiredConnectionStyles : {}),
|
||||
// ...connectionEventStyles,
|
||||
}}
|
||||
/>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(FieldHandle);
|
@ -1,8 +1,8 @@
|
||||
import 'reactflow/dist/style.css';
|
||||
import { Tooltip, Badge, Flex } from '@chakra-ui/react';
|
||||
import { Badge, Flex, Tooltip } from '@chakra-ui/react';
|
||||
import { map } from 'lodash-es';
|
||||
import { FIELDS } from '../types/constants';
|
||||
import { memo } from 'react';
|
||||
import 'reactflow/dist/style.css';
|
||||
import { FIELDS } from '../types/constants';
|
||||
|
||||
const FieldTypeLegend = () => {
|
||||
return (
|
||||
@ -10,8 +10,14 @@ const FieldTypeLegend = () => {
|
||||
{map(FIELDS, ({ title, description, color }, key) => (
|
||||
<Tooltip key={key} label={description}>
|
||||
<Badge
|
||||
colorScheme={color}
|
||||
sx={{ userSelect: 'none' }}
|
||||
sx={{
|
||||
userSelect: 'none',
|
||||
color:
|
||||
parseInt(color.split('.')[1] ?? '0', 10) < 500
|
||||
? 'base.800'
|
||||
: 'base.50',
|
||||
bg: color,
|
||||
}}
|
||||
textAlign="center"
|
||||
>
|
||||
{title}
|
||||
|
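The badge above chooses its text colour inline from the token's shade; here is the same rule pulled into a named helper, shown only as a readability sketch (the 'scale.shade' token format is assumed from FIELDS).

// Sketch: shades below 500 are light backgrounds, so use dark text;
// otherwise use light text. Mirrors the inline expression above.
const pickBadgeTextColor = (colorToken: string): string => {
  const shade = parseInt(colorToken.split('.')[1] ?? '0', 10);
  return shade < 500 ? 'base.800' : 'base.50';
};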
@ -1,4 +1,4 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useToken } from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useCallback } from 'react';
|
||||
import {
|
||||
@ -7,35 +7,51 @@ import {
|
||||
OnConnectEnd,
|
||||
OnConnectStart,
|
||||
OnEdgesChange,
|
||||
OnInit,
|
||||
OnEdgesDelete,
|
||||
OnMoveEnd,
|
||||
OnNodesChange,
|
||||
OnNodesDelete,
|
||||
OnSelectionChangeFunc,
|
||||
ProOptions,
|
||||
ReactFlow,
|
||||
} from 'reactflow';
|
||||
import { useIsValidConnection } from '../hooks/useIsValidConnection';
|
||||
import {
|
||||
connectionEnded,
|
||||
connectionMade,
|
||||
connectionStarted,
|
||||
edgesChanged,
|
||||
edgesDeleted,
|
||||
nodesChanged,
|
||||
setEditorInstance,
|
||||
nodesDeleted,
|
||||
selectedEdgesChanged,
|
||||
selectedNodesChanged,
|
||||
viewportChanged,
|
||||
} from '../store/nodesSlice';
|
||||
import { InvocationComponent } from './InvocationComponent';
|
||||
import ProgressImageNode from './ProgressImageNode';
|
||||
import BottomLeftPanel from './panels/BottomLeftPanel.tsx';
|
||||
import MinimapPanel from './panels/MinimapPanel';
|
||||
import TopCenterPanel from './panels/TopCenterPanel';
|
||||
import TopLeftPanel from './panels/TopLeftPanel';
|
||||
import TopRightPanel from './panels/TopRightPanel';
|
||||
import { CustomConnectionLine } from './CustomConnectionLine';
|
||||
import { edgeTypes } from './CustomEdges';
|
||||
import { nodeTypes } from './CustomNodes';
|
||||
import BottomLeftPanel from './editorPanels/BottomLeftPanel';
|
||||
import MinimapPanel from './editorPanels/MinimapPanel';
|
||||
import TopCenterPanel from './editorPanels/TopCenterPanel';
|
||||
import TopLeftPanel from './editorPanels/TopLeftPanel';
|
||||
import TopRightPanel from './editorPanels/TopRightPanel';
|
||||
|
||||
const nodeTypes = {
|
||||
invocation: InvocationComponent,
|
||||
progress_image: ProgressImageNode,
|
||||
};
|
||||
// TODO: can we support reactflow? if not, we could style the attribution so it matches the app
|
||||
const proOptions: ProOptions = { hideAttribution: true };
|
||||
|
||||
export const Flow = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const nodes = useAppSelector((state: RootState) => state.nodes.nodes);
|
||||
const edges = useAppSelector((state: RootState) => state.nodes.edges);
|
||||
const nodes = useAppSelector((state) => state.nodes.nodes);
|
||||
const edges = useAppSelector((state) => state.nodes.edges);
|
||||
const viewport = useAppSelector((state) => state.nodes.viewport);
|
||||
const shouldSnapToGrid = useAppSelector(
|
||||
(state) => state.nodes.shouldSnapToGrid
|
||||
);
|
||||
|
||||
const isValidConnection = useIsValidConnection();
|
||||
|
||||
const [borderRadius] = useToken('radii', ['base']);
|
||||
|
||||
const onNodesChange: OnNodesChange = useCallback(
|
||||
(changes) => {
|
||||
@ -69,35 +85,66 @@ export const Flow = () => {
|
||||
dispatch(connectionEnded());
|
||||
}, [dispatch]);
|
||||
|
||||
const onInit: OnInit = useCallback(
|
||||
(v) => {
|
||||
dispatch(setEditorInstance(v));
|
||||
if (v) v.fitView();
|
||||
const onEdgesDelete: OnEdgesDelete = useCallback(
|
||||
(edges) => {
|
||||
dispatch(edgesDeleted(edges));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onNodesDelete: OnNodesDelete = useCallback(
|
||||
(nodes) => {
|
||||
dispatch(nodesDeleted(nodes));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleSelectionChange: OnSelectionChangeFunc = useCallback(
|
||||
({ nodes, edges }) => {
|
||||
dispatch(selectedNodesChanged(nodes ? nodes.map((n) => n.id) : []));
|
||||
dispatch(selectedEdgesChanged(edges ? edges.map((e) => e.id) : []));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleMoveEnd: OnMoveEnd = useCallback(
|
||||
(e, viewport) => {
|
||||
dispatch(viewportChanged(viewport));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<ReactFlow
|
||||
defaultViewport={viewport}
|
||||
nodeTypes={nodeTypes}
|
||||
edgeTypes={edgeTypes}
|
||||
nodes={nodes}
|
||||
edges={edges}
|
||||
onNodesChange={onNodesChange}
|
||||
onEdgesChange={onEdgesChange}
|
||||
onEdgesDelete={onEdgesDelete}
|
||||
onNodesDelete={onNodesDelete}
|
||||
onConnectStart={onConnectStart}
|
||||
onConnect={onConnect}
|
||||
onConnectEnd={onConnectEnd}
|
||||
onInit={onInit}
|
||||
defaultEdgeOptions={{
|
||||
style: { strokeWidth: 2 },
|
||||
}}
|
||||
onMoveEnd={handleMoveEnd}
|
||||
connectionLineComponent={CustomConnectionLine}
|
||||
onSelectionChange={handleSelectionChange}
|
||||
isValidConnection={isValidConnection}
|
||||
minZoom={0.2}
|
||||
snapToGrid={shouldSnapToGrid}
|
||||
snapGrid={[25, 25]}
|
||||
connectionRadius={30}
|
||||
proOptions={proOptions}
|
||||
style={{ borderRadius }}
|
||||
>
|
||||
<TopLeftPanel />
|
||||
<TopCenterPanel />
|
||||
<TopRightPanel />
|
||||
<BottomLeftPanel />
|
||||
<Background />
|
||||
<MinimapPanel />
|
||||
<Background />
|
||||
</ReactFlow>
|
||||
);
|
||||
};
|
||||
|
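Flow.tsx above dispatches several nodesSlice actions (viewportChanged, selectedNodesChanged, selectedEdgesChanged, plus the deletion actions) whose reducers are not included in this diff; a hedged Redux Toolkit sketch of what they could look like. The state shape and reducer bodies are assumptions for illustration only.

// Sketch of assumed nodesSlice reducers used by the Flow component above.
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { Edge, Node, Viewport } from 'reactflow';

type NodesState = {
  nodes: Node[];
  edges: Edge[];
  viewport: Viewport;
};

const initialState: NodesState = {
  nodes: [],
  edges: [],
  viewport: { x: 0, y: 0, zoom: 1 },
};

const nodesSlice = createSlice({
  name: 'nodes',
  initialState,
  reducers: {
    viewportChanged: (state, action: PayloadAction<Viewport>) => {
      state.viewport = action.payload;
    },
    selectedNodesChanged: (state, action: PayloadAction<string[]>) => {
      state.nodes.forEach((node) => {
        node.selected = action.payload.includes(node.id);
      });
    },
    selectedEdgesChanged: (state, action: PayloadAction<string[]>) => {
      state.edges.forEach((edge) => {
        edge.selected = action.payload.includes(edge.id);
      });
    },
  },
});

export const { viewportChanged, selectedNodesChanged, selectedEdgesChanged } =
  nodesSlice.actions;
export default nodesSlice.reducer;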
@ -1,55 +0,0 @@
|
||||
import { Flex, Heading, Icon, Tooltip } from '@chakra-ui/react';
|
||||
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/hooks/useBuildInvocation';
|
||||
import { memo } from 'react';
|
||||
import { FaInfoCircle } from 'react-icons/fa';
|
||||
|
||||
interface IAINodeHeaderProps {
|
||||
nodeId?: string;
|
||||
title?: string;
|
||||
description?: string;
|
||||
}
|
||||
|
||||
const IAINodeHeader = (props: IAINodeHeaderProps) => {
|
||||
const { nodeId, title, description } = props;
|
||||
return (
|
||||
<Flex
|
||||
className={DRAG_HANDLE_CLASSNAME}
|
||||
sx={{
|
||||
borderTopRadius: 'md',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'space-between',
|
||||
px: 2,
|
||||
py: 1,
|
||||
bg: 'base.100',
|
||||
_dark: { bg: 'base.900' },
|
||||
}}
|
||||
>
|
||||
<Tooltip label={nodeId}>
|
||||
<Heading
|
||||
size="xs"
|
||||
sx={{
|
||||
fontWeight: 600,
|
||||
color: 'base.900',
|
||||
_dark: { color: 'base.200' },
|
||||
}}
|
||||
>
|
||||
{title}
|
||||
</Heading>
|
||||
</Tooltip>
|
||||
<Tooltip label={description} placement="top" hasArrow shouldWrapChildren>
|
||||
<Icon
|
||||
sx={{
|
||||
h: 'min-content',
|
||||
color: 'base.700',
|
||||
_dark: {
|
||||
color: 'base.300',
|
||||
},
|
||||
}}
|
||||
as={FaInfoCircle}
|
||||
/>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(IAINodeHeader);
|
@ -1,149 +0,0 @@
|
||||
import {
|
||||
Box,
|
||||
Divider,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
HStack,
|
||||
Tooltip,
|
||||
} from '@chakra-ui/react';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
||||
import {
|
||||
InputFieldTemplate,
|
||||
InputFieldValue,
|
||||
InvocationTemplate,
|
||||
} from 'features/nodes/types/types';
|
||||
import { map } from 'lodash-es';
|
||||
import { ReactNode, memo, useCallback } from 'react';
|
||||
import FieldHandle from '../FieldHandle';
|
||||
import InputFieldComponent from '../InputFieldComponent';
|
||||
|
||||
interface IAINodeInputProps {
|
||||
nodeId: string;
|
||||
|
||||
input: InputFieldValue;
|
||||
template?: InputFieldTemplate | undefined;
|
||||
connected: boolean;
|
||||
}
|
||||
|
||||
function IAINodeInput(props: IAINodeInputProps) {
|
||||
const { nodeId, input, template, connected } = props;
|
||||
const isValidConnection = useIsValidConnection();
|
||||
|
||||
return (
|
||||
<Box
|
||||
className="nopan"
|
||||
position="relative"
|
||||
borderColor={
|
||||
!template
|
||||
? 'error.400'
|
||||
: !connected &&
|
||||
['always', 'connectionOnly'].includes(
|
||||
String(template?.inputRequirement)
|
||||
) &&
|
||||
input.value === undefined
|
||||
? 'warning.400'
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
<FormControl isDisabled={!template ? true : connected} pl={2}>
|
||||
{!template ? (
|
||||
<HStack justifyContent="space-between" alignItems="center">
|
||||
<FormLabel>Unknown input: {input.name}</FormLabel>
|
||||
</HStack>
|
||||
) : (
|
||||
<>
|
||||
<HStack justifyContent="space-between" alignItems="center">
|
||||
<HStack>
|
||||
<Tooltip
|
||||
label={template?.description}
|
||||
placement="top"
|
||||
hasArrow
|
||||
shouldWrapChildren
|
||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||
>
|
||||
<FormLabel>{template?.title}</FormLabel>
|
||||
</Tooltip>
|
||||
</HStack>
|
||||
<InputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={input}
|
||||
template={template}
|
||||
/>
|
||||
</HStack>
|
||||
|
||||
{!['never', 'directOnly'].includes(
|
||||
template?.inputRequirement ?? ''
|
||||
) && (
|
||||
<FieldHandle
|
||||
nodeId={nodeId}
|
||||
field={template}
|
||||
isValidConnection={isValidConnection}
|
||||
handleType="target"
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</FormControl>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
interface IAINodeInputsProps {
|
||||
nodeId: string;
|
||||
template: InvocationTemplate;
|
||||
inputs: Record<string, InputFieldValue>;
|
||||
}
|
||||
|
||||
const IAINodeInputs = (props: IAINodeInputsProps) => {
|
||||
const { nodeId, template, inputs } = props;
|
||||
|
||||
const edges = useAppSelector((state: RootState) => state.nodes.edges);
|
||||
|
||||
const renderIAINodeInputs = useCallback(() => {
|
||||
const IAINodeInputsToRender: ReactNode[] = [];
|
||||
const inputSockets = map(inputs);
|
||||
|
||||
inputSockets.forEach((inputSocket, index) => {
|
||||
const inputTemplate = template.inputs[inputSocket.name];
|
||||
|
||||
const isConnected = Boolean(
|
||||
edges.filter((connectedInput) => {
|
||||
return (
|
||||
connectedInput.target === nodeId &&
|
||||
connectedInput.targetHandle === inputSocket.name
|
||||
);
|
||||
}).length
|
||||
);
|
||||
|
||||
if (index < inputSockets.length) {
|
||||
IAINodeInputsToRender.push(
|
||||
<Divider key={`${inputSocket.id}.divider`} />
|
||||
);
|
||||
}
|
||||
|
||||
IAINodeInputsToRender.push(
|
||||
<IAINodeInput
|
||||
key={inputSocket.id}
|
||||
nodeId={nodeId}
|
||||
input={inputSocket}
|
||||
template={inputTemplate}
|
||||
connected={isConnected}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex className="nopan" flexDir="column" gap={2} p={2}>
|
||||
{IAINodeInputsToRender}
|
||||
</Flex>
|
||||
);
|
||||
}, [edges, inputs, nodeId, template.inputs]);
|
||||
|
||||
return renderIAINodeInputs();
|
||||
};
|
||||
|
||||
export default memo(IAINodeInputs);
|
@ -1,97 +0,0 @@
|
||||
import {
|
||||
InvocationTemplate,
|
||||
OutputFieldTemplate,
|
||||
OutputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo, ReactNode, useCallback } from 'react';
|
||||
import { map } from 'lodash-es';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { Box, Flex, FormControl, FormLabel, HStack } from '@chakra-ui/react';
|
||||
import FieldHandle from '../FieldHandle';
|
||||
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
|
||||
|
||||
interface IAINodeOutputProps {
|
||||
nodeId: string;
|
||||
output: OutputFieldValue;
|
||||
template?: OutputFieldTemplate | undefined;
|
||||
connected: boolean;
|
||||
}
|
||||
|
||||
function IAINodeOutput(props: IAINodeOutputProps) {
|
||||
const { nodeId, output, template, connected } = props;
|
||||
const isValidConnection = useIsValidConnection();
|
||||
|
||||
return (
|
||||
<Box position="relative">
|
||||
<FormControl isDisabled={!template ? true : connected} paddingRight={3}>
|
||||
{!template ? (
|
||||
<HStack justifyContent="space-between" alignItems="center">
|
||||
<FormLabel color="error.400">
|
||||
Unknown Output: {output.name}
|
||||
</FormLabel>
|
||||
</HStack>
|
||||
) : (
|
||||
<>
|
||||
<FormLabel textAlign="end" padding={1}>
|
||||
{template?.title}
|
||||
</FormLabel>
|
||||
<FieldHandle
|
||||
key={output.id}
|
||||
nodeId={nodeId}
|
||||
field={template}
|
||||
isValidConnection={isValidConnection}
|
||||
handleType="source"
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
</FormControl>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
interface IAINodeOutputsProps {
|
||||
nodeId: string;
|
||||
template: InvocationTemplate;
|
||||
outputs: Record<string, OutputFieldValue>;
|
||||
}
|
||||
|
||||
const IAINodeOutputs = (props: IAINodeOutputsProps) => {
|
||||
const { nodeId, template, outputs } = props;
|
||||
|
||||
const edges = useAppSelector((state: RootState) => state.nodes.edges);
|
||||
|
||||
const renderIAINodeOutputs = useCallback(() => {
|
||||
const IAINodeOutputsToRender: ReactNode[] = [];
|
||||
const outputSockets = map(outputs);
|
||||
|
||||
outputSockets.forEach((outputSocket) => {
|
||||
const outputTemplate = template.outputs[outputSocket.name];
|
||||
|
||||
const isConnected = Boolean(
|
||||
edges.filter((connectedInput) => {
|
||||
return (
|
||||
connectedInput.source === nodeId &&
|
||||
connectedInput.sourceHandle === outputSocket.name
|
||||
);
|
||||
}).length
|
||||
);
|
||||
|
||||
IAINodeOutputsToRender.push(
|
||||
<IAINodeOutput
|
||||
key={outputSocket.id}
|
||||
nodeId={nodeId}
|
||||
output={outputSocket}
|
||||
template={outputTemplate}
|
||||
connected={isConnected}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
return <Flex flexDir="column">{IAINodeOutputsToRender}</Flex>;
|
||||
}, [edges, nodeId, outputs, template.outputs]);
|
||||
|
||||
return renderIAINodeOutputs();
|
||||
};
|
||||
|
||||
export default memo(IAINodeOutputs);
|
@ -1,252 +0,0 @@
|
||||
import { Box } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import { InputFieldTemplate, InputFieldValue } from '../types/types';
|
||||
import ArrayInputFieldComponent from './fields/ArrayInputFieldComponent';
|
||||
import BooleanInputFieldComponent from './fields/BooleanInputFieldComponent';
|
||||
import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
|
||||
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
|
||||
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
|
||||
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
|
||||
import ControlNetModelInputFieldComponent from './fields/ControlNetModelInputFieldComponent';
|
||||
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
|
||||
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
|
||||
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
|
||||
import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
|
||||
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
|
||||
import LoRAModelInputFieldComponent from './fields/LoRAModelInputFieldComponent';
|
||||
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
|
||||
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
|
||||
import StringInputFieldComponent from './fields/StringInputFieldComponent';
|
||||
import UnetInputFieldComponent from './fields/UnetInputFieldComponent';
|
||||
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
|
||||
import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent';
|
||||
import RefinerModelInputFieldComponent from './fields/RefinerModelInputFieldComponent';
|
||||
|
||||
type InputFieldComponentProps = {
|
||||
nodeId: string;
|
||||
field: InputFieldValue;
|
||||
template: InputFieldTemplate;
|
||||
};
|
||||
|
||||
// build an individual input element based on the schema
|
||||
const InputFieldComponent = (props: InputFieldComponentProps) => {
|
||||
const { nodeId, field, template } = props;
|
||||
const { type } = field;
|
||||
|
||||
if (type === 'string' && template.type === 'string') {
|
||||
return (
|
||||
<StringInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'boolean' && template.type === 'boolean') {
|
||||
return (
|
||||
<BooleanInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
(type === 'integer' && template.type === 'integer') ||
|
||||
(type === 'float' && template.type === 'float')
|
||||
) {
|
||||
return (
|
||||
<NumberInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'enum' && template.type === 'enum') {
|
||||
return (
|
||||
<EnumInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'image' && template.type === 'image') {
|
||||
return (
|
||||
<ImageInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'latents' && template.type === 'latents') {
|
||||
return (
|
||||
<LatentsInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'conditioning' && template.type === 'conditioning') {
|
||||
return (
|
||||
<ConditioningInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'unet' && template.type === 'unet') {
|
||||
return (
|
||||
<UnetInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'clip' && template.type === 'clip') {
|
||||
return (
|
||||
<ClipInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'vae' && template.type === 'vae') {
|
||||
return (
|
||||
<VaeInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'control' && template.type === 'control') {
|
||||
return (
|
||||
<ControlInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'model' && template.type === 'model') {
|
||||
return (
|
||||
<ModelInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'refiner_model' && template.type === 'refiner_model') {
|
||||
return (
|
||||
<RefinerModelInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'vae_model' && template.type === 'vae_model') {
|
||||
return (
|
||||
<VaeModelInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'lora_model' && template.type === 'lora_model') {
|
||||
return (
|
||||
<LoRAModelInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'controlnet_model' && template.type === 'controlnet_model') {
|
||||
return (
|
||||
<ControlNetModelInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'array' && template.type === 'array') {
|
||||
return (
|
||||
<ArrayInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'item' && template.type === 'item') {
|
||||
return (
|
||||
<ItemInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'color' && template.type === 'color') {
|
||||
return (
|
||||
<ColorInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'item' && template.type === 'item') {
|
||||
return (
|
||||
<ItemInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'image_collection' && template.type === 'image_collection') {
|
||||
return (
|
||||
<ImageCollectionInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return <Box p={2}>Unknown field type: {type}</Box>;
|
||||
};
|
||||
|
||||
export default memo(InputFieldComponent);
|
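InputFieldComponent above dispatches on the field type with a long if-chain; an equivalent lookup-table sketch is shown for comparison. Only a subset of field types is wired up, and the Alt suffix marks it as illustration rather than a drop-in replacement.

// Sketch only: the same type dispatch expressed as a lookup map.
// Assumes every field component accepts { nodeId, field, template },
// which the if-chain above already relies on.
import { Box } from '@chakra-ui/react';
import { ComponentType, memo } from 'react';
import { InputFieldTemplate, InputFieldValue } from '../types/types';
import BooleanInputFieldComponent from './fields/BooleanInputFieldComponent';
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
import StringInputFieldComponent from './fields/StringInputFieldComponent';

type InputFieldComponentProps = {
  nodeId: string;
  field: InputFieldValue;
  template: InputFieldTemplate;
};

// `any` papers over the per-type prop narrowing of the field components.
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const FIELD_COMPONENTS: Record<string, ComponentType<any>> = {
  string: StringInputFieldComponent,
  boolean: BooleanInputFieldComponent,
  integer: NumberInputFieldComponent,
  float: NumberInputFieldComponent,
  enum: EnumInputFieldComponent,
  image: ImageInputFieldComponent,
};

const InputFieldComponentAlt = (props: InputFieldComponentProps) => {
  const { nodeId, field, template } = props;
  const Component = FIELD_COMPONENTS[field.type];

  // Same guard as the if-chain: the value and template types must agree.
  if (!Component || field.type !== template.type) {
    return <Box p={2}>Unknown field type: {field.type}</Box>;
  }

  return <Component nodeId={nodeId} field={field} template={template} />;
};

export default memo(InputFieldComponentAlt);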
@ -0,0 +1,84 @@
import { Flex } from '@chakra-ui/react';
import {
  InvocationNodeData,
  InvocationTemplate,
} from 'features/nodes/types/types';
import { map, some } from 'lodash-es';
import { memo, useMemo } from 'react';
import { NodeProps } from 'reactflow';
import InputField from '../fields/InputField';
import OutputField from '../fields/OutputField';
import NodeFooter, { FOOTER_FIELDS } from './NodeFooter';
import NodeHeader from './NodeHeader';
import NodeWrapper from './NodeWrapper';

type Props = {
  nodeProps: NodeProps<InvocationNodeData>;
  nodeTemplate: InvocationTemplate;
};

const InvocationNode = ({ nodeProps, nodeTemplate }: Props) => {
  const { id: nodeId, data } = nodeProps;
  const { inputs, outputs, isOpen } = data;

  const inputFields = useMemo(
    () => map(inputs).filter((i) => i.name !== 'is_intermediate'),
    [inputs]
  );
  const outputFields = useMemo(() => map(outputs), [outputs]);

  const withFooter = useMemo(
    () => some(outputs, (output) => FOOTER_FIELDS.includes(output.type)),
    [outputs]
  );

  return (
    <NodeWrapper nodeProps={nodeProps}>
      <NodeHeader nodeProps={nodeProps} nodeTemplate={nodeTemplate} />
      {isOpen && (
        <>
          <Flex
            layerStyle="nodeBody"
            className={'nopan'}
            sx={{
              cursor: 'auto',
              flexDirection: 'column',
              w: 'full',
              h: 'full',
              py: 1,
              gap: 1,
              borderBottomRadius: withFooter ? 0 : 'base',
            }}
          >
            <Flex
              className="nopan"
              sx={{ flexDir: 'column', px: 2, w: 'full', h: 'full' }}
            >
              {outputFields.map((field) => (
                <OutputField
                  key={`${nodeId}.${field.id}.input-field`}
                  nodeProps={nodeProps}
                  nodeTemplate={nodeTemplate}
                  field={field}
                />
              ))}
              {inputFields.map((field) => (
                <InputField
                  key={`${nodeId}.${field.id}.input-field`}
                  nodeProps={nodeProps}
                  nodeTemplate={nodeTemplate}
                  field={field}
                />
              ))}
            </Flex>
          </Flex>
          {withFooter && (
            <NodeFooter nodeProps={nodeProps} nodeTemplate={nodeTemplate} />
          )}
        </>
      )}
    </NodeWrapper>
  );
};

export default memo(InvocationNode);
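For orientation, a component shaped like InvocationNode is normally rendered by reactflow through its `nodeTypes` prop, whose components receive only `NodeProps`; a thin wrapper is therefore assumed to resolve the matching template first. The sketch below is hypothetical wiring for illustration, not code from this diff, and `useNodeTemplate` is a made-up stub:

// Hypothetical wiring sketch, not part of this diff.
import { memo } from 'react';
import { NodeProps } from 'reactflow';
import {
  InvocationNodeData,
  InvocationTemplate,
} from 'features/nodes/types/types';
import InvocationNode from './InvocationNode';

// Assumed template lookup; in the real app this presumably comes from the store.
const useNodeTemplate = (data: InvocationNodeData): InvocationTemplate | undefined => {
  void data; // stub for the sketch
  return undefined;
};

const InvocationNodeWrapper = memo((props: NodeProps<InvocationNodeData>) => {
  const nodeTemplate = useNodeTemplate(props.data);
  if (!nodeTemplate) {
    return null;
  }
  return <InvocationNode nodeProps={props} nodeTemplate={nodeTemplate} />;
});
InvocationNodeWrapper.displayName = 'InvocationNodeWrapper';

// Kept as a stable module-level object so reactflow does not re-register node types on every render:
// <ReactFlow nodeTypes={nodeTypes} nodes={nodes} edges={edges} />
export const nodeTypes = { invocation: InvocationNodeWrapper };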
@ -0,0 +1,57 @@
import { ChevronUpIcon } from '@chakra-ui/icons';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { nodeIsOpenChanged } from 'features/nodes/store/nodesSlice';
import { NodeData } from 'features/nodes/types/types';
import { memo, useCallback } from 'react';
import { NodeProps, useUpdateNodeInternals } from 'reactflow';

interface Props {
  nodeProps: NodeProps<NodeData>;
}

const NodeCollapseButton = (props: Props) => {
  const { id: nodeId, isOpen } = props.nodeProps.data;
  const dispatch = useAppDispatch();
  const updateNodeInternals = useUpdateNodeInternals();

  const handleClick = useCallback(() => {
    dispatch(nodeIsOpenChanged({ nodeId, isOpen: !isOpen }));
    updateNodeInternals(nodeId);
  }, [dispatch, isOpen, nodeId, updateNodeInternals]);

  return (
    <IAIIconButton
      className="nopan"
      onClick={handleClick}
      aria-label="Minimize"
      sx={{
        minW: 8,
        w: 8,
        h: 8,
        color: 'base.500',
        _dark: {
          color: 'base.500',
        },
        _hover: {
          color: 'base.700',
          _dark: {
            color: 'base.300',
          },
        },
      }}
      variant="link"
      icon={
        <ChevronUpIcon
          sx={{
            transform: isOpen ? 'rotate(0deg)' : 'rotate(180deg)',
            transitionProperty: 'common',
            transitionDuration: 'normal',
          }}
        />
      }
    />
  );
};

export default memo(NodeCollapseButton);
@ -0,0 +1,74 @@
import { useColorModeValue } from '@chakra-ui/react';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import {
  InvocationNodeData,
  InvocationTemplate,
} from 'features/nodes/types/types';
import { map } from 'lodash-es';
import { CSSProperties, memo, useMemo } from 'react';
import { Handle, NodeProps, Position } from 'reactflow';

interface Props {
  nodeProps: NodeProps<InvocationNodeData>;
  nodeTemplate: InvocationTemplate;
}

const NodeCollapsedHandles = (props: Props) => {
  const { data } = props.nodeProps;
  const { base400, base600 } = useChakraThemeTokens();
  const backgroundColor = useColorModeValue(base400, base600);

  const dummyHandleStyles: CSSProperties = useMemo(
    () => ({
      borderWidth: 0,
      borderRadius: '3px',
      width: '1rem',
      height: '1rem',
      backgroundColor,
      zIndex: -1,
    }),
    [backgroundColor]
  );

  return (
    <>
      <Handle
        type="target"
        id={`${data.id}-collapsed-target`}
        isConnectable={false}
        position={Position.Left}
        style={{ ...dummyHandleStyles, left: '-0.5rem' }}
      />
      {map(data.inputs, (input) => (
        <Handle
          key={`${data.id}-${input.name}-collapsed-input-handle`}
          type="target"
          id={input.name}
          isValidConnection={() => false}
          position={Position.Left}
          style={{ visibility: 'hidden' }}
        />
      ))}
      <Handle
        type="source"
        id={`${data.id}-collapsed-source`}
        isValidConnection={() => false}
        isConnectable={false}
        position={Position.Right}
        style={{ ...dummyHandleStyles, right: '-0.5rem' }}
      />
      {map(data.outputs, (output) => (
        <Handle
          key={`${data.id}-${output.name}-collapsed-output-handle`}
          type="source"
          id={output.name}
          isValidConnection={() => false}
          position={Position.Right}
          style={{ visibility: 'hidden' }}
        />
      ))}
    </>
  );
};

export default memo(NodeCollapsedHandles);
@ -0,0 +1,80 @@
import {
  Checkbox,
  Flex,
  FormControl,
  FormLabel,
  Spacer,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import {
  InvocationNodeData,
  InvocationTemplate,
} from 'features/nodes/types/types';
import { some } from 'lodash-es';
import { ChangeEvent, memo, useCallback, useMemo } from 'react';
import { NodeProps } from 'reactflow';

export const IMAGE_FIELDS = ['ImageField', 'ImageCollection'];
export const FOOTER_FIELDS = IMAGE_FIELDS;

type Props = {
  nodeProps: NodeProps<InvocationNodeData>;
  nodeTemplate: InvocationTemplate;
};

const NodeFooter = (props: Props) => {
  const { nodeProps, nodeTemplate } = props;
  const dispatch = useAppDispatch();

  const hasImageOutput = useMemo(
    () =>
      some(nodeTemplate?.outputs, (output) =>
        IMAGE_FIELDS.includes(output.type)
      ),
    [nodeTemplate?.outputs]
  );

  const handleChangeIsIntermediate = useCallback(
    (e: ChangeEvent<HTMLInputElement>) => {
      dispatch(
        fieldBooleanValueChanged({
          nodeId: nodeProps.data.id,
          fieldName: 'is_intermediate',
          value: !e.target.checked,
        })
      );
    },
    [dispatch, nodeProps.data.id]
  );

  return (
    <Flex
      className={DRAG_HANDLE_CLASSNAME}
      layerStyle="nodeFooter"
      sx={{
        w: 'full',
        borderBottomRadius: 'base',
        px: 2,
        py: 0,
        h: 6,
      }}
    >
      <Spacer />
      {hasImageOutput && (
        <FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
          <FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Save Output</FormLabel>
          <Checkbox
            className="nopan"
            size="sm"
            onChange={handleChangeIsIntermediate}
            isChecked={!nodeProps.data.inputs['is_intermediate']?.value}
          />
        </FormControl>
      )}
    </Flex>
  );
};

export default memo(NodeFooter);
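One detail worth noting in NodeFooter above: the checkbox is intentionally inverted, since a checked "Save Output" corresponds to `is_intermediate` being false (and, per the helper text elsewhere in this diff, intermediate images are not added to the gallery). A tiny illustrative sketch of that mapping, with made-up helper names:

// Illustrative only: the "Save Output" checkbox and the is_intermediate field are inverses.
const isIntermediateFromSaveOutput = (saveOutputChecked: boolean): boolean =>
  !saveOutputChecked; // checked "Save Output" -> is_intermediate = false

const saveOutputFromIsIntermediate = (isIntermediate: boolean | undefined): boolean =>
  !isIntermediate; // falsy is_intermediate -> checkbox rendered as checked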
@ -0,0 +1,54 @@
import { Flex } from '@chakra-ui/react';
import {
  InvocationNodeData,
  InvocationTemplate,
} from 'features/nodes/types/types';
import { memo } from 'react';
import { NodeProps } from 'reactflow';
import NodeCollapseButton from '../Invocation/NodeCollapseButton';
import NodeCollapsedHandles from '../Invocation/NodeCollapsedHandles';
import NodeNotesEdit from '../Invocation/NodeNotesEdit';
import NodeStatusIndicator from '../Invocation/NodeStatusIndicator';
import NodeTitle from '../Invocation/NodeTitle';

type Props = {
  nodeProps: NodeProps<InvocationNodeData>;
  nodeTemplate: InvocationTemplate;
};

const NodeHeader = (props: Props) => {
  const { nodeProps, nodeTemplate } = props;
  const { isOpen } = nodeProps.data;

  return (
    <Flex
      layerStyle="nodeHeader"
      sx={{
        borderTopRadius: 'base',
        borderBottomRadius: isOpen ? 0 : 'base',
        alignItems: 'center',
        justifyContent: 'space-between',
        h: 8,
        textAlign: 'center',
        fontWeight: 600,
        color: 'base.700',
        _dark: { color: 'base.200' },
      }}
    >
      <NodeCollapseButton nodeProps={nodeProps} />
      <NodeTitle nodeData={nodeProps.data} title={nodeTemplate.title} />
      <Flex alignItems="center">
        <NodeStatusIndicator nodeProps={nodeProps} />
        <NodeNotesEdit nodeProps={nodeProps} nodeTemplate={nodeTemplate} />
      </Flex>
      {!isOpen && (
        <NodeCollapsedHandles
          nodeProps={nodeProps}
          nodeTemplate={nodeTemplate}
        />
      )}
    </Flex>
  );
};

export default memo(NodeHeader);
@ -0,0 +1,113 @@
import {
  Flex,
  FormControl,
  FormLabel,
  Icon,
  Modal,
  ModalBody,
  ModalCloseButton,
  ModalContent,
  ModalFooter,
  ModalHeader,
  ModalOverlay,
  Text,
  Tooltip,
  useDisclosure,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAITextarea from 'common/components/IAITextarea';
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import {
  InvocationNodeData,
  InvocationTemplate,
} from 'features/nodes/types/types';
import { ChangeEvent, memo, useCallback } from 'react';
import { FaInfoCircle } from 'react-icons/fa';
import { NodeProps } from 'reactflow';

interface Props {
  nodeProps: NodeProps<InvocationNodeData>;
  nodeTemplate: InvocationTemplate;
}

const NodeNotesEdit = (props: Props) => {
  const { nodeProps, nodeTemplate } = props;
  const { data } = nodeProps;
  const { isOpen, onOpen, onClose } = useDisclosure();
  const dispatch = useAppDispatch();
  const handleNotesChanged = useCallback(
    (e: ChangeEvent<HTMLTextAreaElement>) => {
      dispatch(nodeNotesChanged({ nodeId: data.id, notes: e.target.value }));
    },
    [data.id, dispatch]
  );

  return (
    <>
      <Tooltip
        label={
          nodeTemplate ? (
            <TooltipContent nodeProps={nodeProps} nodeTemplate={nodeTemplate} />
          ) : undefined
        }
        placement="top"
        shouldWrapChildren
      >
        <Flex
          className={DRAG_HANDLE_CLASSNAME}
          onClick={onOpen}
          sx={{
            alignItems: 'center',
            justifyContent: 'center',
            w: 8,
            h: 8,
            cursor: 'pointer',
          }}
        >
          <Icon
            as={FaInfoCircle}
            sx={{ boxSize: 4, w: 8, color: 'base.400' }}
          />
        </Flex>
      </Tooltip>

      <Modal isOpen={isOpen} onClose={onClose} isCentered>
        <ModalOverlay />
        <ModalContent>
          <ModalHeader>
            {data.label || nodeTemplate?.title || 'Unknown Node'}
          </ModalHeader>
          <ModalCloseButton />
          <ModalBody>
            <FormControl>
              <FormLabel>Notes</FormLabel>
              <IAITextarea
                value={data.notes}
                onChange={handleNotesChanged}
                rows={10}
              />
            </FormControl>
          </ModalBody>
          <ModalFooter />
        </ModalContent>
      </Modal>
    </>
  );
};

export default memo(NodeNotesEdit);

type TooltipContentProps = Props;

const TooltipContent = (props: TooltipContentProps) => {
  return (
    <Flex sx={{ flexDir: 'column' }}>
      <Text sx={{ fontWeight: 600 }}>{props.nodeTemplate?.title}</Text>
      <Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
        {props.nodeTemplate?.description}
      </Text>
      {props.nodeProps.data.notes && <Text>{props.nodeProps.data.notes}</Text>}
    </Flex>
  );
};
@ -2,7 +2,10 @@ import { NODE_MIN_WIDTH } from 'features/nodes/types/constants';
import { memo } from 'react';
import { NodeResizeControl, NodeResizerProps } from 'reactflow';

const IAINodeResizer = (props: NodeResizerProps) => {
// this causes https://github.com/invoke-ai/InvokeAI/issues/4140
// not using it for now

const NodeResizer = (props: NodeResizerProps) => {
  const { ...rest } = props;
  return (
    <NodeResizeControl
@ -21,4 +24,4 @@ const IAINodeResizer = (props: NodeResizerProps) => {
  );
};

export default memo(IAINodeResizer);
export default memo(NodeResizer);
@ -0,0 +1,69 @@
import { Flex } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIPopover from 'common/components/IAIPopover';
import IAISwitch from 'common/components/IAISwitch';
import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
import { InvocationNodeData } from 'features/nodes/types/types';
import { ChangeEvent, memo, useCallback } from 'react';
import { FaBars } from 'react-icons/fa';

interface Props {
  data: InvocationNodeData;
}

const NodeSettings = (props: Props) => {
  const { data } = props;
  const dispatch = useAppDispatch();

  const handleChangeIsIntermediate = useCallback(
    (e: ChangeEvent<HTMLInputElement>) => {
      dispatch(
        fieldBooleanValueChanged({
          nodeId: data.id,
          fieldName: 'is_intermediate',
          value: e.target.checked,
        })
      );
    },
    [data.id, dispatch]
  );

  return (
    <IAIPopover
      isLazy={false}
      triggerComponent={
        <IAIIconButton
          className="nopan"
          aria-label="Node Settings"
          variant="link"
          sx={{
            minW: 8,
            color: 'base.500',
            _dark: {
              color: 'base.500',
            },
            _hover: {
              color: 'base.700',
              _dark: {
                color: 'base.300',
              },
            },
          }}
          icon={<FaBars />}
        />
      }
    >
      <Flex sx={{ flexDir: 'column', gap: 4, w: 64 }}>
        <IAISwitch
          label="Intermediate"
          isChecked={Boolean(data.inputs['is_intermediate']?.value)}
          onChange={handleChangeIsIntermediate}
          helperText="The outputs of intermediate nodes are considered temporary objects. Intermediate images are not added to the gallery."
        />
      </Flex>
    </IAIPopover>
  );
};

export default memo(NodeSettings);
@ -0,0 +1,185 @@
import {
  Badge,
  CircularProgress,
  Flex,
  Icon,
  Image,
  Text,
  Tooltip,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import {
  InvocationNodeData,
  NodeExecutionState,
  NodeStatus,
} from 'features/nodes/types/types';
import { memo, useMemo } from 'react';
import { FaCheck, FaEllipsisH, FaExclamation } from 'react-icons/fa';
import { NodeProps } from 'reactflow';

type Props = {
  nodeProps: NodeProps<InvocationNodeData>;
};

const iconBoxSize = 3;
const circleStyles = {
  circle: {
    transitionProperty: 'none',
    transitionDuration: '0s',
  },
  '.chakra-progress__track': { stroke: 'transparent' },
};

const NodeStatusIndicator = (props: Props) => {
  const nodeId = props.nodeProps.data.id;
  const selectNodeExecutionState = useMemo(
    () =>
      createSelector(
        stateSelector,
        ({ nodes }) => nodes.nodeExecutionStates[nodeId]
      ),
    [nodeId]
  );

  const nodeExecutionState = useAppSelector(selectNodeExecutionState);

  if (!nodeExecutionState) {
    return null;
  }

  return (
    <Tooltip
      label={<TooltipLabel nodeExecutionState={nodeExecutionState} />}
      placement="top"
    >
      <Flex
        className={DRAG_HANDLE_CLASSNAME}
        sx={{
          w: 5,
          h: 'full',
          alignItems: 'center',
          justifyContent: 'flex-end',
        }}
      >
        <StatusIcon nodeExecutionState={nodeExecutionState} />
      </Flex>
    </Tooltip>
  );
};

export default memo(NodeStatusIndicator);

type TooltipLabelProps = {
  nodeExecutionState: NodeExecutionState;
};

const TooltipLabel = ({ nodeExecutionState }: TooltipLabelProps) => {
  const { status, progress, progressImage } = nodeExecutionState;
  if (status === NodeStatus.PENDING) {
    return <Text>Pending</Text>;
  }

  if (status === NodeStatus.IN_PROGRESS) {
    if (progressImage) {
      return (
        <Flex sx={{ pos: 'relative', pt: 1.5, pb: 0.5 }}>
          <Image
            src={progressImage.dataURL}
            sx={{ w: 32, h: 32, borderRadius: 'base', objectFit: 'contain' }}
          />
          {progress !== null && (
            <Badge
              variant="solid"
              sx={{ pos: 'absolute', top: 2.5, insetInlineEnd: 1 }}
            >
              {Math.round(progress * 100)}%
            </Badge>
          )}
        </Flex>
      );
    }

    if (progress !== null) {
      return <Text>In Progress ({Math.round(progress * 100)}%)</Text>;
    }

    return <Text>In Progress</Text>;
  }

  if (status === NodeStatus.COMPLETED) {
    return <Text>Completed</Text>;
  }

  if (status === NodeStatus.FAILED) {
return <Text>nodeExecutionState.error</Text>;
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
type StatusIconProps = {
|
||||
nodeExecutionState: NodeExecutionState;
|
||||
};
|
||||
|
||||
const StatusIcon = (props: StatusIconProps) => {
|
||||
const { progress, status } = props.nodeExecutionState;
|
||||
if (status === NodeStatus.PENDING) {
|
||||
return (
|
||||
<Icon
|
||||
as={FaEllipsisH}
|
||||
sx={{
|
||||
boxSize: iconBoxSize,
|
||||
color: 'base.600',
|
||||
_dark: { color: 'base.300' },
|
||||
}}
|
||||
/>
|
||||
);
|
||||
}
|
||||
if (status === NodeStatus.IN_PROGRESS) {
|
||||
return progress === null ? (
|
||||
<CircularProgress
|
||||
isIndeterminate
|
||||
size="14px"
|
||||
color="base.500"
|
||||
thickness={14}
|
||||
sx={circleStyles}
|
||||
/>
|
||||
) : (
|
||||
<CircularProgress
|
||||
value={Math.round(progress * 100)}
|
||||
size="14px"
|
||||
color="base.500"
|
||||
thickness={14}
|
||||
sx={circleStyles}
|
||||
/>
|
||||
);
|
||||
}
|
||||
if (status === NodeStatus.COMPLETED) {
|
||||
return (
|
||||
<Icon
|
||||
as={FaCheck}
|
||||
sx={{
|
||||
boxSize: iconBoxSize,
|
||||
color: 'ok.600',
|
||||
_dark: { color: 'ok.300' },
|
||||
}}
|
||||
/>
|
||||
);
|
||||
}
|
||||
if (status === NodeStatus.FAILED) {
|
||||
return (
|
||||
<Icon
|
||||
as={FaExclamation}
|
||||
sx={{
|
||||
boxSize: iconBoxSize,
|
||||
color: 'error.600',
|
||||
_dark: { color: 'error.300' },
|
||||
}}
|
||||
/>
|
||||
);
|
||||
}
|
||||
return null;
|
||||
};
|
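The selector in NodeStatusIndicator above is created inside `useMemo` so each node instance gets its own memoized selector keyed by its `nodeId`, instead of all nodes sharing a single cache entry. A generic sketch of that per-instance selector-factory pattern; the state shape below is a made-up stand-in, not this app's store:

// Generic per-instance selector factory; RootState here is hypothetical.
import { createSelector } from '@reduxjs/toolkit';

type ExecutionState = { status: number; progress: number | null };
type RootState = {
  nodes: { nodeExecutionStates: Record<string, ExecutionState | undefined> };
};

const makeSelectNodeExecutionState = (nodeId: string) =>
  createSelector(
    (state: RootState) => state.nodes.nodeExecutionStates,
    (nodeExecutionStates) => nodeExecutionStates[nodeId]
  );

// Usage in a component (mirroring the code above):
//   const selector = useMemo(() => makeSelectNodeExecutionState(nodeId), [nodeId]);
//   const executionState = useAppSelector(selector);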
Some files were not shown because too many files have changed in this diff.