class ImagesUpdatedFromListResult(BaseModel):
    """Response body for the bulk star/unstar endpoints."""

    updated_image_names: list[str] = Field(description="The image names that were updated")


def _set_starred(image_names: list[str], starred: bool) -> list[str]:
    """Set the `starred` flag on each named image and return the names that were updated.

    This is best-effort: an image whose update fails is left out of the result
    and the remaining images are still processed.
    """
    updated_image_names: list[str] = []
    for image_name in image_names:
        try:
            ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=starred))
        except Exception:
            # Skip images that cannot be updated. `Exception` (not a bare `except`) so that
            # KeyboardInterrupt, SystemExit and task cancellation are never swallowed.
            continue
        updated_image_names.append(image_name)
    return updated_image_names


@images_router.post("/star", operation_id="star_images_in_list", response_model=ImagesUpdatedFromListResult)
async def star_images_in_list(
    image_names: list[str] = Body(description="The list of names of images to star", embed=True),
) -> ImagesUpdatedFromListResult:
    """Star every image in `image_names`; the response lists the names that were actually updated."""
    try:
        return ImagesUpdatedFromListResult(updated_image_names=_set_starred(image_names, starred=True))
    except Exception as e:
        raise HTTPException(status_code=500, detail="Failed to star images") from e


@images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=ImagesUpdatedFromListResult)
async def unstar_images_in_list(
    image_names: list[str] = Body(description="The list of names of images to unstar", embed=True),
) -> ImagesUpdatedFromListResult:
    """Unstar every image in `image_names`; the response lists the names that were actually updated."""
    try:
        return ImagesUpdatedFromListResult(updated_image_names=_set_starred(image_names, starred=False))
    except Exception as e:
        raise HTTPException(status_code=500, detail="Failed to unstar images") from e
class FieldDescriptions:
    """Named description strings, kept in one place so they can be referenced by attribute."""

    # Each attribute is a plain string constant; nothing here is computed.
    denoising_start = "When to start denoising, expressed as a percentage of total steps"
    denoising_end = "When to stop denoising, expressed as a percentage of total steps"
    cfg_scale = "Classifier-Free Guidance scale"
    scheduler = "Scheduler to use during inference"
    positive_cond = "Positive conditioning tensor"
    negative_cond = "Negative conditioning tensor"
    noise = "Noise tensor"
    clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
    unet = "UNet (scheduler, LoRAs)"
    vae = "VAE"
    cond = "Conditioning tensor"
    controlnet_model = "ControlNet model to load"
    vae_model = "VAE model to load"
    lora_model = "LoRA model to load"
    main_model = "Main model (UNet, VAE, CLIP) to load"
    sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
    sdxl_refiner_model = "SDXL Refiner Main Model (UNet, VAE, CLIP2) to load"
    onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
    lora_weight = "The weight at which the LoRA is applied to each model"
    compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
    raw_prompt = "Raw prompt text (no parsing)"
    sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor"
    skipped_layers = "Number of layers to skip in text encoder"
    seed = "Seed for random number generation"
    steps = "Number of steps to run"
    width = "Width of output (px)"
    height = "Height of output (px)"
    control = "ControlNet(s) to apply"
    denoised_latents = "Denoised latents tensor"
    latents = "Latents tensor"
    strength = "Strength of denoising (proportional to steps)"
    core_metadata = "Optional core metadata to be written to image"
    interp_mode = "Interpolation mode"
    torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
    fp32 = "Whether or not to use full float32 precision"
    precision = "Precision to use"
    tiled = "Processing using overlapping tiles (reduce memory consumption)"
    detect_res = "Pixel resolution for detection"
    image_res = "Pixel resolution for output image"
    safe_mode = "Whether or not to use safe mode"
    scribble_mode = "Whether or not to use scribble mode"
    scale_factor = "The factor by which to scale"
    num_1 = "The first number"
    num_2 = "The second number"
    mask = "The mask to use for the operation"
+ """ + + # region Primitives + Integer = "integer" + Float = "float" + Boolean = "boolean" + String = "string" + Array = "array" + Image = "ImageField" + Latents = "LatentsField" + Conditioning = "ConditioningField" + Control = "ControlField" + Color = "ColorField" + ImageCollection = "ImageCollection" + ConditioningCollection = "ConditioningCollection" + ColorCollection = "ColorCollection" + LatentsCollection = "LatentsCollection" + IntegerCollection = "IntegerCollection" + FloatCollection = "FloatCollection" + StringCollection = "StringCollection" + BooleanCollection = "BooleanCollection" + # endregion + + # region Models + MainModel = "MainModelField" + SDXLMainModel = "SDXLMainModelField" + SDXLRefinerModel = "SDXLRefinerModelField" + ONNXModel = "ONNXModelField" + VaeModel = "VaeModelField" + LoRAModel = "LoRAModelField" + ControlNetModel = "ControlNetModelField" + UNet = "UNetField" + Vae = "VaeField" + CLIP = "ClipField" + # endregion + + # region Iterate/Collect + Collection = "Collection" + CollectionItem = "CollectionItem" + # endregion + + # region Misc + FilePath = "FilePath" + Enum = "enum" + # endregion + + +class UIComponent(str, Enum): + """ + The type of UI component to use for a field, used to override the default components, which are \ + inferred from the field type. + """ + + None_ = "none" + Textarea = "textarea" + Slider = "slider" + + +class _InputField(BaseModel): + """ + *DO NOT USE* + This helper class is used to tell the client about our custom field attributes via OpenAPI + schema generation, and Typescript type generation from that schema. It serves no functional + purpose in the backend. + """ + + input: Input + ui_hidden: bool + ui_type: Optional[UIType] + ui_component: Optional[UIComponent] + + +class _OutputField(BaseModel): + """ + *DO NOT USE* + This helper class is used to tell the client about our custom field attributes via OpenAPI + schema generation, and Typescript type generation from that schema. 
def InputField(
    *args: Any,
    default: Any = Undefined,
    default_factory: Optional[NoArgAnyCallable] = None,
    alias: Optional[str] = None,
    title: Optional[str] = None,
    description: Optional[str] = None,
    exclude: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
    include: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
    const: Optional[bool] = None,
    gt: Optional[float] = None,
    ge: Optional[float] = None,
    lt: Optional[float] = None,
    le: Optional[float] = None,
    multiple_of: Optional[float] = None,
    allow_inf_nan: Optional[bool] = None,
    max_digits: Optional[int] = None,
    decimal_places: Optional[int] = None,
    min_items: Optional[int] = None,
    max_items: Optional[int] = None,
    unique_items: Optional[bool] = None,
    min_length: Optional[int] = None,
    max_length: Optional[int] = None,
    allow_mutation: bool = True,
    regex: Optional[str] = None,
    discriminator: Optional[str] = None,
    repr: bool = True,
    input: Input = Input.Any,
    ui_type: Optional[UIType] = None,
    ui_component: Optional[UIComponent] = None,
    ui_hidden: bool = False,
    **kwargs: Any,
) -> Any:
    """
    Creates an input field for an invocation.

    This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization)
    that adds a few extra parameters to support graph execution and the node editor UI.
    Every argument, including the four extra ones below, is forwarded to `Field` by keyword.

    :param Input input: [Input.Any] The kind of input this field requires.
        `Input.Direct` means a value must be provided on instantiation.
        `Input.Connection` means the value must be provided by a connection.
        `Input.Any` means either will do.

    :param UIType ui_type: [None] Optionally provides an extra type hint for the UI.
        In some situations, the field's type is not enough to infer the correct UI type.
        For example, model selection fields should render a dropdown UI component to select a model.
        Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use
        `MainModelField`. So to ensure the base-model-specific UI is rendered, you can use
        `UIType.SDXLMainModel` to indicate that the field is an SDXL main model field.

    :param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI.
        The UI will always render a suitable component, but sometimes you want something different
        than the default. For example, a `string` field will default to a single-line input, but you
        may want a multi-line textarea instead. For this case, you could provide `UIComponent.Textarea`.

    :param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
    """
    return Field(
        *args,
        default=default,
        default_factory=default_factory,
        alias=alias,
        title=title,
        description=description,
        exclude=exclude,
        include=include,
        const=const,
        gt=gt,
        ge=ge,
        lt=lt,
        le=le,
        multiple_of=multiple_of,
        allow_inf_nan=allow_inf_nan,
        max_digits=max_digits,
        decimal_places=decimal_places,
        min_items=min_items,
        max_items=max_items,
        unique_items=unique_items,
        min_length=min_length,
        max_length=max_length,
        allow_mutation=allow_mutation,
        regex=regex,
        discriminator=discriminator,
        repr=repr,
        # The four keywords below are not standard `Field` parameters; they are the
        # node-specific additions described in the docstring.
        input=input,
        ui_type=ui_type,
        ui_component=ui_component,
        ui_hidden=ui_hidden,
        **kwargs,
    )
Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + multiple_of: Optional[float] = None, + allow_inf_nan: Optional[bool] = None, + max_digits: Optional[int] = None, + decimal_places: Optional[int] = None, + min_items: Optional[int] = None, + max_items: Optional[int] = None, + unique_items: Optional[bool] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + allow_mutation: bool = True, + regex: Optional[str] = None, + discriminator: Optional[str] = None, + repr: bool = True, + ui_type: Optional[UIType] = None, + ui_hidden: bool = False, + **kwargs: Any, +) -> Any: + """ + Creates an output field for an invocation output. + + This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \ + that adds a few extra parameters to support graph execution and the node editor UI. + + :param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \ + In some situations, the field's type is not enough to infer the correct UI type. \ + For example, model selection fields should render a dropdown UI component to select a model. \ + Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \ + `MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \ + `UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field. + + : param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. 
class UIConfigBase(BaseModel):
    """
    Provides additional node configuration to the UI.
    This is used internally by the @tags and @title decorator logic. You probably want to use those
    decorators, though you may add this class to a node definition to specify the title and tags.
    """

    # `default=None`, not `default_factory=None`: a factory must be a zero-argument callable,
    # and the intent here is simply "no tags unless provided".
    tags: Optional[list[str]] = Field(default=None, description="The tags to display in the UI")
    title: Optional[str] = Field(default=None, description="The display name of the node")
def __init__(self, **data: Any) -> None:
    """Instantiate the node while skipping validation of required fields a connection may supply.

    A field is temporarily removed from `self.__fields__` when all of the following hold:
    its `input` extra is `Input.Connection` or `Input.Any`, it is required, and `data`
    contains no value for it. The parent `__init__` then runs on the remaining fields,
    and the removed fields are put back afterwards, even if it raises.
    """
    # nodes may have required fields, that can accept input from connections
    # on instantiation of the model, we need to exclude these from validation
    restore = dict()
    try:
        # Snapshot the names first: the mapping is mutated inside the loop.
        field_names = list(self.__fields__.keys())
        for field_name in field_names:
            # if the field is required and may get its value from a connection, exclude it from validation
            field = self.__fields__[field_name]
            _input = field.field_info.extra.get("input", None)
            if _input in [Input.Connection, Input.Any] and field.required:
                if field_name not in data:
                    restore[field_name] = self.__fields__.pop(field_name)
        # instantiate the node, which will validate the data
        super().__init__(**data)
    finally:
        # restore the removed fields
        # NOTE(review): presumably `__fields__` is the class-level field map shared by all
        # instances; if so, this pop/restore is not safe under concurrent instantiation. Confirm.
        # NOTE(review): assuming a plain dict, restored fields are re-inserted at the end,
        # so their position in the mapping changes. Confirm nothing depends on field order.
        for field_name, field in restore.items():
            self.__fields__[field_name] = field
def tags(*tags: str) -> Callable[[Type[T]], Type[T]]:
    """Adds tags to the invocation. Use this to streamline finding the invocation in the UI."""

    def wrapper(cls: Type[T]) -> Type[T]:
        own_config_name = cls.__qualname__ + ".UIConfig"
        # Reuse the class's UIConfig only when it was created for this exact class;
        # otherwise give the class a fresh UIConfigBase subclass of its own.
        has_own_config = hasattr(cls, "UIConfig") and cls.UIConfig.__qualname__ == own_config_name
        if not has_own_config:
            cls.UIConfig = type(own_config_name, (UIConfigBase,), dict())
        cls.UIConfig.tags = list(tags)
        return cls

    return wrapper
@title("Integer Range of Size")
@tags("range", "integer", "size", "collection")
class RangeOfSizeInvocation(BaseInvocation):
    """Creates a range of `size` integers starting at `start` and incremented by `step`"""

    type: Literal["range_of_size"] = "range_of_size"

    # Inputs
    start: int = InputField(default=0, description="The start of the range")
    size: int = InputField(default=1, description="The number of values")
    step: int = InputField(default=1, description="The step of the range")

    def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
        """Return the collection `start, start + step, ...` containing `size` values.

        Raises ValueError (from `range`) if `step` is 0.
        """
        # The exclusive stop must scale with the step. With `start + size` as the stop,
        # any step other than 1 produced fewer than `size` values.
        stop = self.start + self.size * self.step
        return IntegerCollectionOutput(collection=list(range(self.start, stop, self.step)))
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: rng = np.random.default_rng(self.seed) - return IntCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size))) - - -class ImageCollectionInvocation(BaseInvocation): - """Load a collection of images and provide it as output.""" - - # fmt: off - type: Literal["image_collection"] = "image_collection" - - # Inputs - images: list[ImageField] = Field( - default=[], description="The image collection to load" - ) - # fmt: on - - def invoke(self, context: InvocationContext) -> ImageCollectionOutput: - return ImageCollectionOutput(collection=self.images) - - class Config(InvocationConfig): - schema_extra = { - "ui": { - "type_hints": { - "title": "Image Collection", - "images": "image_collection", - } - }, - } + return IntegerCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size))) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 86565366d9..0be33ce701 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -1,32 +1,35 @@ -from typing import Literal, Optional, Union, List, Annotated -from pydantic import BaseModel, Field import re - -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig -from .model import ClipField - -from ...backend.util.devices import torch_dtype -from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent -from ...backend.model_management import BaseModelType, ModelType, SubModelType, ModelPatcher +from dataclasses import dataclass +from typing import List, Literal, Union import torch from compel import Compel, ReturnedEmbeddingsType from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment -from ...backend.util.devices import torch_dtype -from ...backend.model_management import ModelType -from 
...backend.model_management.models import ModelNotFoundException +from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput + +from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import ( + BasicConditioningInfo, + SDXLConditioningInfo, +) + +from ...backend.model_management import ModelPatcher, ModelType from ...backend.model_management.lora import ModelPatcher -from ...backend.stable_diffusion import InvokeAIDiffuserComponent, BasicConditioningInfo, SDXLConditioningInfo -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext +from ...backend.model_management.models import ModelNotFoundException +from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent +from ...backend.util.devices import torch_dtype +from .baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + FieldDescriptions, + Input, + InputField, + InvocationContext, + OutputField, + UIComponent, + tags, + title, +) from .model import ClipField -from dataclasses import dataclass - - -class ConditioningField(BaseModel): - conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data") - - class Config: - schema_extra = {"required": ["conditioning_name"]} @dataclass @@ -41,32 +44,26 @@ class ConditioningFieldData: # PerpNeg = "perp_neg" -class CompelOutput(BaseInvocationOutput): - """Compel parser output""" - - # fmt: off - type: Literal["compel_output"] = "compel_output" - - conditioning: ConditioningField = Field(default=None, description="Conditioning") - # fmt: on - - +@title("Compel Prompt") +@tags("prompt", "compel") class CompelInvocation(BaseInvocation): """Parse prompt using compel package to conditioning.""" type: Literal["compel"] = "compel" - prompt: str = Field(default="", description="Prompt") - clip: ClipField = Field(None, description="Clip to use") - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - 
"ui": {"title": "Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}}, - } + prompt: str = InputField( + default="", + description=FieldDescriptions.compel_prompt, + ui_component=UIComponent.Textarea, + ) + clip: ClipField = InputField( + title="CLIP", + description=FieldDescriptions.clip, + input=Input.Connection, + ) @torch.no_grad() - def invoke(self, context: InvocationContext) -> CompelOutput: + def invoke(self, context: InvocationContext) -> ConditioningOutput: tokenizer_info = context.services.model_manager.get_model( **self.clip.tokenizer.dict(), context=context, @@ -149,7 +146,7 @@ class CompelInvocation(BaseInvocation): conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" context.services.latents.save(conditioning_name, conditioning_data) - return CompelOutput( + return ConditioningOutput( conditioning=ConditioningField( conditioning_name=conditioning_name, ), @@ -270,30 +267,26 @@ class SDXLPromptInvocationBase: return c, c_pooled, ec +@title("SDXL Compel Prompt") +@tags("sdxl", "compel", "prompt") class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" type: Literal["sdxl_compel_prompt"] = "sdxl_compel_prompt" - prompt: str = Field(default="", description="Prompt") - style: str = Field(default="", description="Style prompt") - original_width: int = Field(1024, description="") - original_height: int = Field(1024, description="") - crop_top: int = Field(0, description="") - crop_left: int = Field(0, description="") - target_width: int = Field(1024, description="") - target_height: int = Field(1024, description="") - clip: ClipField = Field(None, description="Clip to use") - clip2: ClipField = Field(None, description="Clip2 to use") - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "SDXL Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}}, - } + 
prompt: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea) + style: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea) + original_width: int = InputField(default=1024, description="") + original_height: int = InputField(default=1024, description="") + crop_top: int = InputField(default=0, description="") + crop_left: int = InputField(default=0, description="") + target_width: int = InputField(default=1024, description="") + target_height: int = InputField(default=1024, description="") + clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) + clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) @torch.no_grad() - def invoke(self, context: InvocationContext) -> CompelOutput: + def invoke(self, context: InvocationContext) -> ConditioningOutput: c1, c1_pooled, ec1 = self.run_clip_compel( context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True ) @@ -326,38 +319,32 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" context.services.latents.save(conditioning_name, conditioning_data) - return CompelOutput( + return ConditioningOutput( conditioning=ConditioningField( conditioning_name=conditioning_name, ), ) +@title("SDXL Refiner Compel Prompt") +@tags("sdxl", "compel", "prompt") class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt" - style: str = Field(default="", description="Style prompt") # TODO: ? 
- original_width: int = Field(1024, description="") - original_height: int = Field(1024, description="") - crop_top: int = Field(0, description="") - crop_left: int = Field(0, description="") - aesthetic_score: float = Field(6.0, description="") - clip2: ClipField = Field(None, description="Clip to use") - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "SDXL Refiner Prompt (Compel)", - "tags": ["prompt", "compel"], - "type_hints": {"model": "model"}, - }, - } + style: str = InputField( + default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea + ) # TODO: ? + original_width: int = InputField(default=1024, description="") + original_height: int = InputField(default=1024, description="") + crop_top: int = InputField(default=0, description="") + crop_left: int = InputField(default=0, description="") + aesthetic_score: float = InputField(default=6.0, description=FieldDescriptions.sdxl_aesthetic) + clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) @torch.no_grad() - def invoke(self, context: InvocationContext) -> CompelOutput: + def invoke(self, context: InvocationContext) -> ConditioningOutput: # TODO: if there will appear lora for refiner - write proper prefix c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "", zero_on_empty=False) @@ -380,7 +367,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" context.services.latents.save(conditioning_name, conditioning_data) - return CompelOutput( + return ConditioningOutput( conditioning=ConditioningField( conditioning_name=conditioning_name, ), @@ -391,21 +378,18 @@ class ClipSkipInvocationOutput(BaseInvocationOutput): """Clip skip node output""" type: Literal["clip_skip_output"] = "clip_skip_output" - clip: ClipField = Field(None, description="Clip with 
skipped layers") + clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") +@title("CLIP Skip") +@tags("clipskip", "clip", "skip") class ClipSkipInvocation(BaseInvocation): """Skip layers in clip text_encoder model.""" type: Literal["clip_skip"] = "clip_skip" - clip: ClipField = Field(None, description="Clip to use") - skipped_layers: int = Field(0, description="Number of layers to skip in text_encoder") - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "CLIP Skip", "tags": ["clip", "skip"]}, - } + clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP") + skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers) def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput: self.clip.skipped_layers += self.skipped_layers diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index d2b2d44526..811f767dee 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -26,79 +26,31 @@ from controlnet_aux.util import HWC3, ade_palette from PIL import Image from pydantic import BaseModel, Field, validator +from invokeai.app.invocations.primitives import ImageField, ImageOutput + + from ...backend.model_management import BaseModelType, ModelType -from ..models.image import ImageCategory, ImageField, ResourceOrigin -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext -from ..models.image import ImageOutput, PILInvocationConfig +from ..models.image import ImageCategory, ResourceOrigin +from .baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + FieldDescriptions, + InputField, + Input, + InvocationContext, + OutputField, + UIType, + tags, + title, +) -CONTROLNET_DEFAULT_MODELS = [ - 
########################################### - # lllyasviel sd v1.5, ControlNet v1.0 models - ############################################## - "lllyasviel/sd-controlnet-canny", - "lllyasviel/sd-controlnet-depth", - "lllyasviel/sd-controlnet-hed", - "lllyasviel/sd-controlnet-seg", - "lllyasviel/sd-controlnet-openpose", - "lllyasviel/sd-controlnet-scribble", - "lllyasviel/sd-controlnet-normal", - "lllyasviel/sd-controlnet-mlsd", - ############################################# - # lllyasviel sd v1.5, ControlNet v1.1 models - ############################################# - "lllyasviel/control_v11p_sd15_canny", - "lllyasviel/control_v11p_sd15_openpose", - "lllyasviel/control_v11p_sd15_seg", - # "lllyasviel/control_v11p_sd15_depth", # broken - "lllyasviel/control_v11f1p_sd15_depth", - "lllyasviel/control_v11p_sd15_normalbae", - "lllyasviel/control_v11p_sd15_scribble", - "lllyasviel/control_v11p_sd15_mlsd", - "lllyasviel/control_v11p_sd15_softedge", - "lllyasviel/control_v11p_sd15s2_lineart_anime", - "lllyasviel/control_v11p_sd15_lineart", - "lllyasviel/control_v11p_sd15_inpaint", - # "lllyasviel/control_v11u_sd15_tile", - # problem (temporary?) with huffingface "lllyasviel/control_v11u_sd15_tile", - # so for now replace "lllyasviel/control_v11f1e_sd15_tile", - "lllyasviel/control_v11e_sd15_shuffle", - "lllyasviel/control_v11e_sd15_ip2p", - "lllyasviel/control_v11f1e_sd15_tile", - ################################################# - # thibaud sd v2.1 models (ControlNet v1.0? or v1.1? 
- ################################################## - "thibaud/controlnet-sd21-openpose-diffusers", - "thibaud/controlnet-sd21-canny-diffusers", - "thibaud/controlnet-sd21-depth-diffusers", - "thibaud/controlnet-sd21-scribble-diffusers", - "thibaud/controlnet-sd21-hed-diffusers", - "thibaud/controlnet-sd21-zoedepth-diffusers", - "thibaud/controlnet-sd21-color-diffusers", - "thibaud/controlnet-sd21-openposev2-diffusers", - "thibaud/controlnet-sd21-lineart-diffusers", - "thibaud/controlnet-sd21-normalbae-diffusers", - "thibaud/controlnet-sd21-ade20k-diffusers", - ############################################## - # ControlNetMediaPipeface, ControlNet v1.1 - ############################################## - # ["CrucibleAI/ControlNetMediaPipeFace", "diffusion_sd15"], # SD 1.5 - # diffusion_sd15 needs to be passed to from_pretrained() as subfolder arg - # hacked t2l to split to model & subfolder if format is "model,subfolder" - "CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15", # SD 1.5 - "CrucibleAI/ControlNetMediaPipeFace", # SD 2.1? 
-] -CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)] -CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])] +CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"] CONTROLNET_RESIZE_VALUES = Literal[ - tuple( - [ - "just_resize", - "crop_resize", - "fill_resize", - "just_resize_simple", - ] - ) + "just_resize", + "crop_resize", + "fill_resize", + "just_resize_simple", ] @@ -110,9 +62,8 @@ class ControlNetModelField(BaseModel): class ControlField(BaseModel): - image: ImageField = Field(default=None, description="The control image") - control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use") - # control_weight: Optional[float] = Field(default=1, description="weight given to controlnet") + image: ImageField = Field(description="The control image") + control_model: ControlNetModelField = Field(description="The ControlNet model to use") control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet") begin_step_percent: float = Field( default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)" @@ -135,60 +86,39 @@ class ControlField(BaseModel): raise ValueError("Control weights must be within -1 to 2 range") return v - class Config: - schema_extra = { - "required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"], - "ui": { - "type_hints": { - "control_weight": "float", - "control_model": "controlnet_model", - # "control_weight": "number", - } - }, - } - class ControlOutput(BaseInvocationOutput): """node output for ControlNet info""" - # fmt: off type: Literal["control_output"] = "control_output" - control: ControlField = Field(default=None, description="The control info") - # fmt: on + + # Outputs + control: ControlField = OutputField(description=FieldDescriptions.control) +@title("ControlNet") 
+@tags("controlnet") class ControlNetInvocation(BaseInvocation): """Collects ControlNet info to pass to other nodes""" - # fmt: off type: Literal["controlnet"] = "controlnet" - # Inputs - image: ImageField = Field(default=None, description="The control image") - control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny", - description="control model used") - control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet") - begin_step_percent: float = Field(default=0, ge=-1, le=2, - description="When the ControlNet is first applied (% of total steps)") - end_step_percent: float = Field(default=1, ge=0, le=1, - description="When the ControlNet is last applied (% of total steps)") - control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used") - resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode used") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "ControlNet", - "tags": ["controlnet", "latents"], - "type_hints": { - "model": "model", - "control": "control", - # "cfg_scale": "float", - "cfg_scale": "number", - "control_weight": "float", - }, - }, - } + # Inputs + image: ImageField = InputField(description="The control image") + control_model: ControlNetModelField = InputField( + default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct + ) + control_weight: Union[float, List[float]] = InputField( + default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float + ) + begin_step_percent: float = InputField( + default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)" + ) + end_step_percent: float = InputField( + default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)" + ) + control_mode: CONTROLNET_MODE_VALUES = 
InputField(default="balanced", description="The control mode used") + resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used") def invoke(self, context: InvocationContext) -> ControlOutput: return ControlOutput( @@ -204,19 +134,13 @@ class ControlNetInvocation(BaseInvocation): ) -class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig): +class ImageProcessorInvocation(BaseInvocation): """Base class for invocations that preprocess images for ControlNet""" - # fmt: off type: Literal["image_processor"] = "image_processor" - # Inputs - image: ImageField = Field(default=None, description="The image to process") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Image Processor", "tags": ["image", "processor"]}, - } + # Inputs + image: ImageField = InputField(description="The image to process") def run_processor(self, image): # superclass just passes through image without processing @@ -255,20 +179,20 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig): ) -class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): +@title("Canny Processor") +@tags("controlnet", "canny") +class CannyImageProcessorInvocation(ImageProcessorInvocation): """Canny edge detection for ControlNet""" - # fmt: off type: Literal["canny_image_processor"] = "canny_image_processor" - # Input - low_threshold: int = Field(default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)") - high_threshold: int = Field(default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Canny Processor", "tags": ["controlnet", "canny", "image", "processor"]}, - } + # Input + low_threshold: int = InputField( + default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)" + ) + high_threshold: int = 
InputField( + default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)" + ) def run_processor(self, image): canny_processor = CannyDetector() @@ -276,23 +200,19 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi return processed_image -class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): +@title("HED (softedge) Processor") +@tags("controlnet", "hed", "softedge") +class HedImageProcessorInvocation(ImageProcessorInvocation): """Applies HED edge detection to image""" - # fmt: off type: Literal["hed_image_processor"] = "hed_image_processor" - # Inputs - detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection") - image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image") - # safe not supported in controlnet_aux v0.0.3 - # safe: bool = Field(default=False, description="whether to use safe mode") - scribble: bool = Field(default=False, description="Whether to use scribble mode") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Softedge(HED) Processor", "tags": ["controlnet", "softedge", "hed", "image", "processor"]}, - } + # Inputs + detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) + image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) + # safe not supported in controlnet_aux v0.0.3 + # safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode) + scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode) def run_processor(self, image): hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators") @@ -307,21 +227,17 @@ class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig) return processed_image -class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): 
+@title("Lineart Processor") +@tags("controlnet", "lineart") +class LineartImageProcessorInvocation(ImageProcessorInvocation): """Applies line art processing to image""" - # fmt: off type: Literal["lineart_image_processor"] = "lineart_image_processor" - # Inputs - detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection") - image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image") - coarse: bool = Field(default=False, description="Whether to use coarse mode") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Lineart Processor", "tags": ["controlnet", "lineart", "image", "processor"]}, - } + # Inputs + detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) + image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) + coarse: bool = InputField(default=False, description="Whether to use coarse mode") def run_processor(self, image): lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators") @@ -331,23 +247,16 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCon return processed_image -class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): +@title("Lineart Anime Processor") +@tags("controlnet", "lineart", "anime") +class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation): """Applies line art anime processing to image""" - # fmt: off type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor" - # Inputs - detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection") - image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "Lineart Anime Processor", - "tags": ["controlnet", "lineart", 
"anime", "image", "processor"], - }, - } + # Inputs + detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) + image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) def run_processor(self, image): processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators") @@ -359,21 +268,17 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocati return processed_image -class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): +@title("Openpose Processor") +@tags("controlnet", "openpose", "pose") +class OpenposeImageProcessorInvocation(ImageProcessorInvocation): """Applies Openpose processing to image""" - # fmt: off type: Literal["openpose_image_processor"] = "openpose_image_processor" - # Inputs - hand_and_face: bool = Field(default=False, description="Whether to use hands and face mode") - detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection") - image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Openpose Processor", "tags": ["controlnet", "openpose", "image", "processor"]}, - } + # Inputs + hand_and_face: bool = InputField(default=False, description="Whether to use hands and face mode") + detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) + image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) def run_processor(self, image): openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators") @@ -386,22 +291,18 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCo return processed_image -class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): +@title("Midas (Depth) Processor") 
+@tags("controlnet", "midas", "depth") +class MidasDepthImageProcessorInvocation(ImageProcessorInvocation): """Applies Midas depth processing to image""" - # fmt: off type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor" - # Inputs - a_mult: float = Field(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)") - bg_th: float = Field(default=0.1, ge=0, description="Midas parameter `bg_th`") - # depth_and_normal not supported in controlnet_aux v0.0.3 - # depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Midas (Depth) Processor", "tags": ["controlnet", "midas", "depth", "image", "processor"]}, - } + # Inputs + a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)") + bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`") + # depth_and_normal not supported in controlnet_aux v0.0.3 + # depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode") def run_processor(self, image): midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators") @@ -415,20 +316,16 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocation return processed_image -class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): +@title("Normal BAE Processor") +@tags("controlnet", "normal", "bae") +class NormalbaeImageProcessorInvocation(ImageProcessorInvocation): """Applies NormalBae processing to image""" - # fmt: off type: Literal["normalbae_image_processor"] = "normalbae_image_processor" - # Inputs - detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection") - image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image") - # fmt: on - class 
Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Normal BAE Processor", "tags": ["controlnet", "normal", "bae", "image", "processor"]}, - } + # Inputs + detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) + image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) def run_processor(self, image): normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators") @@ -438,22 +335,18 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationC return processed_image -class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): +@title("MLSD Processor") +@tags("controlnet", "mlsd") +class MlsdImageProcessorInvocation(ImageProcessorInvocation): """Applies MLSD processing to image""" - # fmt: off type: Literal["mlsd_image_processor"] = "mlsd_image_processor" - # Inputs - detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection") - image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image") - thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_v`") - thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_d`") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "MLSD Processor", "tags": ["controlnet", "mlsd", "image", "processor"]}, - } + # Inputs + detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) + image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) + thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`") + thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`") def run_processor(self, image): mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators") @@ -467,22 +360,18 @@ class 
MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig return processed_image -class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): +@title("PIDI Processor") +@tags("controlnet", "pidi") +class PidiImageProcessorInvocation(ImageProcessorInvocation): """Applies PIDI processing to image""" - # fmt: off type: Literal["pidi_image_processor"] = "pidi_image_processor" - # Inputs - detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection") - image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image") - safe: bool = Field(default=False, description="Whether to use safe mode") - scribble: bool = Field(default=False, description="Whether to use scribble mode") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "PIDI Processor", "tags": ["controlnet", "pidi", "image", "processor"]}, - } + # Inputs + detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) + image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) + safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode) + scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode) def run_processor(self, image): pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators") @@ -496,26 +385,19 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig return processed_image -class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): +@title("Content Shuffle Processor") +@tags("controlnet", "contentshuffle") +class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation): """Applies content shuffle processing to image""" - # fmt: off type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor" - # Inputs - detect_resolution: int = 
Field(default=512, ge=0, description="The pixel resolution for detection") - image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image") - h: Optional[int] = Field(default=512, ge=0, description="Content shuffle `h` parameter") - w: Optional[int] = Field(default=512, ge=0, description="Content shuffle `w` parameter") - f: Optional[int] = Field(default=256, ge=0, description="Content shuffle `f` parameter") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "Content Shuffle Processor", - "tags": ["controlnet", "contentshuffle", "image", "processor"], - }, - } + # Inputs + detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) + image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) + h: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `h` parameter") + w: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `w` parameter") + f: Optional[int] = InputField(default=256, ge=0, description="Content shuffle `f` parameter") def run_processor(self, image): content_shuffle_processor = ContentShuffleDetector() @@ -531,17 +413,12 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvoca # should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13 -class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): +@title("Zoe (Depth) Processor") +@tags("controlnet", "zoe", "depth") +class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation): """Applies Zoe depth processing to image""" - # fmt: off type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor" - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Zoe (Depth) Processor", "tags": ["controlnet", "zoe", "depth", "image", "processor"]}, - } def run_processor(self, image): zoe_depth_processor = 
ZoeDetector.from_pretrained("lllyasviel/Annotators") @@ -549,20 +426,16 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCo return processed_image -class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): +@title("Mediapipe Face Processor") +@tags("controlnet", "mediapipe", "face") +class MediapipeFaceProcessorInvocation(ImageProcessorInvocation): """Applies mediapipe face processing to image""" - # fmt: off type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor" - # Inputs - max_faces: int = Field(default=1, ge=1, description="Maximum number of faces to detect") - min_confidence: float = Field(default=0.5, ge=0, le=1, description="Minimum confidence for face detection") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Mediapipe Processor", "tags": ["controlnet", "mediapipe", "image", "processor"]}, - } + # Inputs + max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect") + min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection") def run_processor(self, image): # MediaPipeFaceDetector throws an error if image has alpha channel @@ -574,23 +447,19 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationCo return processed_image -class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): +@title("Leres (Depth) Processor") +@tags("controlnet", "leres", "depth") +class LeresImageProcessorInvocation(ImageProcessorInvocation): """Applies leres processing to image""" - # fmt: off type: Literal["leres_image_processor"] = "leres_image_processor" - # Inputs - thr_a: float = Field(default=0, description="Leres parameter `thr_a`") - thr_b: float = Field(default=0, description="Leres parameter `thr_b`") - boost: bool = Field(default=False, description="Whether to use boost mode") - detect_resolution: int = Field(default=512, 
ge=0, description="The pixel resolution for detection") - image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Leres (Depth) Processor", "tags": ["controlnet", "leres", "depth", "image", "processor"]}, - } + # Inputs + thr_a: float = InputField(default=0, description="Leres parameter `thr_a`") + thr_b: float = InputField(default=0, description="Leres parameter `thr_b`") + boost: bool = InputField(default=False, description="Whether to use boost mode") + detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) + image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) def run_processor(self, image): leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators") @@ -605,21 +474,16 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi return processed_image -class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): - # fmt: off - type: Literal["tile_image_processor"] = "tile_image_processor" - # Inputs - #res: int = Field(default=512, ge=0, le=1024, description="The pixel resolution for each tile") - down_sampling_rate: float = Field(default=1.0, ge=1.0, le=8.0, description="Down sampling rate") - # fmt: on +@title("Tile Resample Processor") +@tags("controlnet", "tile") +class TileResamplerProcessorInvocation(ImageProcessorInvocation): + """Tile resampler processor""" - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "Tile Resample Processor", - "tags": ["controlnet", "tile", "resample", "image", "processor"], - }, - } + type: Literal["tile_image_processor"] = "tile_image_processor" + + # Inputs + # res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile") + down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, 
description="Down sampling rate") # tile_resample copied from sd-webui-controlnet/scripts/processor.py def tile_resample( @@ -648,20 +512,12 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationCo return processed_image -class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): +@title("Segment Anything Processor") +@tags("controlnet", "segmentanything") +class SegmentAnythingProcessorInvocation(ImageProcessorInvocation): """Applies segment anything processing to image""" - # fmt: off type: Literal["segment_anything_processor"] = "segment_anything_processor" - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "Segment Anything Processor", - "tags": ["controlnet", "segment", "anything", "sam", "image", "processor"], - }, - } def run_processor(self, image): # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints") diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index bd3a4adbe4..f318434211 100644 --- a/invokeai/app/invocations/cv.py +++ b/invokeai/app/invocations/cv.py @@ -5,40 +5,22 @@ from typing import Literal import cv2 as cv import numpy from PIL import Image, ImageOps -from pydantic import BaseModel, Field +from invokeai.app.invocations.primitives import ImageField, ImageOutput -from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin -from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig -from .image import ImageOutput +from invokeai.app.models.image import ImageCategory, ResourceOrigin +from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title -class CvInvocationConfig(BaseModel): - """Helper class to provide all OpenCV invocations with additional config""" - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "tags": ["cv", "image"], - }, - } - - -class 
CvInpaintInvocation(BaseInvocation, CvInvocationConfig): +@title("OpenCV Inpaint") +@tags("opencv", "inpaint") +class CvInpaintInvocation(BaseInvocation): """Simple inpaint using opencv.""" - # fmt: off type: Literal["cv_inpaint"] = "cv_inpaint" # Inputs - image: ImageField = Field(default=None, description="The image to inpaint") - mask: ImageField = Field(default=None, description="The mask to use when inpainting") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "OpenCV Inpaint", "tags": ["opencv", "inpaint"]}, - } + image: ImageField = InputField(description="The image to inpaint") + mask: ImageField = InputField(description="The mask to use when inpainting") def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 2c47020207..f4a1648196 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -1,60 +1,31 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) from pathlib import Path -from typing import Literal, Optional, Union +from typing import Literal, Optional import cv2 import numpy from PIL import Image, ImageChops, ImageFilter, ImageOps -from pydantic import Field from invokeai.app.invocations.metadata import CoreMetadata +from invokeai.app.invocations.primitives import ImageField, ImageOutput from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark from invokeai.backend.image_util.safety_checker import SafetyChecker -from ..models.image import ImageCategory, ImageField, ImageOutput, MaskOutput, PILInvocationConfig, ResourceOrigin -from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext - - -class LoadImageInvocation(BaseInvocation): - """Load an image and provide it as output.""" - - # fmt: off - type: Literal["load_image"] = "load_image" - - # Inputs - image: 
Optional[ImageField] = Field( - default=None, description="The image to load" - ) - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Load Image", "tags": ["image", "load"]}, - } - - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) - - return ImageOutput( - image=ImageField(image_name=self.image.image_name), - width=image.width, - height=image.height, - ) +from ..models.image import ImageCategory, ResourceOrigin +from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, tags, title +@title("Show Image") +@tags("image") class ShowImageInvocation(BaseInvocation): """Displays a provided image, and passes it forward in the pipeline.""" + # Metadata type: Literal["show_image"] = "show_image" # Inputs - image: Optional[ImageField] = Field(default=None, description="The image to show") - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Show Image", "tags": ["image", "show"]}, - } + image: ImageField = InputField(description="The image to show") def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -70,24 +41,20 @@ class ShowImageInvocation(BaseInvocation): ) -class ImageCropInvocation(BaseInvocation, PILInvocationConfig): +@title("Crop Image") +@tags("image", "crop") +class ImageCropInvocation(BaseInvocation): """Crops an image to a specified box. 
The box can be outside of the image.""" - # fmt: off + # Metadata type: Literal["img_crop"] = "img_crop" # Inputs - image: Optional[ImageField] = Field(default=None, description="The image to crop") - x: int = Field(default=0, description="The left x coordinate of the crop rectangle") - y: int = Field(default=0, description="The top y coordinate of the crop rectangle") - width: int = Field(default=512, gt=0, description="The width of the crop rectangle") - height: int = Field(default=512, gt=0, description="The height of the crop rectangle") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Crop Image", "tags": ["image", "crop"]}, - } + image: ImageField = InputField(description="The image to crop") + x: int = InputField(default=0, description="The left x coordinate of the crop rectangle") + y: int = InputField(default=0, description="The top y coordinate of the crop rectangle") + width: int = InputField(default=512, gt=0, description="The width of the crop rectangle") + height: int = InputField(default=512, gt=0, description="The height of the crop rectangle") def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -111,24 +78,23 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig): ) -class ImagePasteInvocation(BaseInvocation, PILInvocationConfig): +@title("Paste Image") +@tags("image", "paste") +class ImagePasteInvocation(BaseInvocation): """Pastes an image into another image.""" - # fmt: off + # Metadata type: Literal["img_paste"] = "img_paste" # Inputs - base_image: Optional[ImageField] = Field(default=None, description="The base image") - image: Optional[ImageField] = Field(default=None, description="The image to paste") - mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting") - x: int = Field(default=0, description="The left x coordinate at which to paste the image") - y: int = Field(default=0, 
description="The top y coordinate at which to paste the image") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Paste Image", "tags": ["image", "paste"]}, - } + base_image: ImageField = InputField(description="The base image") + image: ImageField = InputField(description="The image to paste") + mask: Optional[ImageField] = InputField( + default=None, + description="The mask to use when pasting", + ) + x: int = InputField(default=0, description="The left x coordinate at which to paste the image") + y: int = InputField(default=0, description="The top y coordinate at which to paste the image") def invoke(self, context: InvocationContext) -> ImageOutput: base_image = context.services.images.get_pil_image(self.base_image.image_name) @@ -164,23 +130,19 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig): ) -class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig): +@title("Mask from Alpha") +@tags("image", "mask") +class MaskFromAlphaInvocation(BaseInvocation): """Extracts the alpha channel of an image as a mask.""" - # fmt: off + # Metadata type: Literal["tomask"] = "tomask" # Inputs - image: Optional[ImageField] = Field(default=None, description="The image to create the mask from") - invert: bool = Field(default=False, description="Whether or not to invert the mask") - # fmt: on + image: ImageField = InputField(description="The image to create the mask from") + invert: bool = InputField(default=False, description="Whether or not to invert the mask") - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Mask From Alpha", "tags": ["image", "mask", "alpha"]}, - } - - def invoke(self, context: InvocationContext) -> MaskOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) image_mask = image.split()[-1] @@ -196,28 +158,24 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig): 
is_intermediate=self.is_intermediate, ) - return MaskOutput( - mask=ImageField(image_name=image_dto.image_name), + return ImageOutput( + image=ImageField(image_name=image_dto.image_name), width=image_dto.width, height=image_dto.height, ) -class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig): +@title("Multiply Images") +@tags("image", "multiply") +class ImageMultiplyInvocation(BaseInvocation): """Multiplies two images together using `PIL.ImageChops.multiply()`.""" - # fmt: off + # Metadata type: Literal["img_mul"] = "img_mul" # Inputs - image1: Optional[ImageField] = Field(default=None, description="The first image to multiply") - image2: Optional[ImageField] = Field(default=None, description="The second image to multiply") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Multiply Images", "tags": ["image", "multiply"]}, - } + image1: ImageField = InputField(description="The first image to multiply") + image2: ImageField = InputField(description="The second image to multiply") def invoke(self, context: InvocationContext) -> ImageOutput: image1 = context.services.images.get_pil_image(self.image1.image_name) @@ -244,21 +202,17 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig): IMAGE_CHANNELS = Literal["A", "R", "G", "B"] -class ImageChannelInvocation(BaseInvocation, PILInvocationConfig): +@title("Extract Image Channel") +@tags("image", "channel") +class ImageChannelInvocation(BaseInvocation): """Gets a channel from an image.""" - # fmt: off + # Metadata type: Literal["img_chan"] = "img_chan" # Inputs - image: Optional[ImageField] = Field(default=None, description="The image to get the channel from") - channel: IMAGE_CHANNELS = Field(default="A", description="The channel to get") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Image Channel", "tags": ["image", "channel"]}, - } + image: ImageField = InputField(description="The image to get the channel from") + 
channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get") def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -284,21 +238,17 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig): IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"] -class ImageConvertInvocation(BaseInvocation, PILInvocationConfig): +@title("Convert Image Mode") +@tags("image", "convert") +class ImageConvertInvocation(BaseInvocation): """Converts an image to a different mode.""" - # fmt: off + # Metadata type: Literal["img_conv"] = "img_conv" # Inputs - image: Optional[ImageField] = Field(default=None, description="The image to convert") - mode: IMAGE_MODES = Field(default="L", description="The mode to convert to") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Convert Image", "tags": ["image", "convert"]}, - } + image: ImageField = InputField(description="The image to convert") + mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to") def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -321,22 +271,19 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig): ) -class ImageBlurInvocation(BaseInvocation, PILInvocationConfig): +@title("Blur Image") +@tags("image", "blur") +class ImageBlurInvocation(BaseInvocation): """Blurs an image""" - # fmt: off + # Metadata type: Literal["img_blur"] = "img_blur" # Inputs - image: Optional[ImageField] = Field(default=None, description="The image to blur") - radius: float = Field(default=8.0, ge=0, description="The blur radius") - blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Blur Image", "tags": ["image", "blur"]}, - } + 
image: ImageField = InputField(description="The image to blur") + radius: float = InputField(default=8.0, ge=0, description="The blur radius") + # Metadata + blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur") def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -382,23 +329,19 @@ PIL_RESAMPLING_MAP = { } -class ImageResizeInvocation(BaseInvocation, PILInvocationConfig): +@title("Resize Image") +@tags("image", "resize") +class ImageResizeInvocation(BaseInvocation): """Resizes an image to specific dimensions""" - # fmt: off + # Metadata type: Literal["img_resize"] = "img_resize" # Inputs - image: Optional[ImageField] = Field(default=None, description="The image to resize") - width: Union[int, None] = Field(ge=64, multiple_of=8, description="The width to resize to (px)") - height: Union[int, None] = Field(ge=64, multiple_of=8, description="The height to resize to (px)") - resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Resize Image", "tags": ["image", "resize"]}, - } + image: ImageField = InputField(description="The image to resize") + width: int = InputField(default=512, ge=64, multiple_of=8, description="The width to resize to (px)") + height: int = InputField(default=512, ge=64, multiple_of=8, description="The height to resize to (px)") + resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -426,22 +369,22 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig): ) -class ImageScaleInvocation(BaseInvocation, PILInvocationConfig): +@title("Scale Image") +@tags("image", "scale") +class 
ImageScaleInvocation(BaseInvocation): """Scales an image by a factor""" - # fmt: off + # Metadata type: Literal["img_scale"] = "img_scale" # Inputs - image: Optional[ImageField] = Field(default=None, description="The image to scale") - scale_factor: Optional[float] = Field(default=2.0, gt=0, description="The factor by which to scale the image") - resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Scale Image", "tags": ["image", "scale"]}, - } + image: ImageField = InputField(description="The image to scale") + scale_factor: float = InputField( + default=2.0, + gt=0, + description="The factor by which to scale the image", + ) + resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -471,22 +414,18 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig): ) -class ImageLerpInvocation(BaseInvocation, PILInvocationConfig): +@title("Lerp Image") +@tags("image", "lerp") +class ImageLerpInvocation(BaseInvocation): """Linear interpolation of all pixels of an image""" - # fmt: off + # Metadata type: Literal["img_lerp"] = "img_lerp" # Inputs - image: Optional[ImageField] = Field(default=None, description="The image to lerp") - min: int = Field(default=0, ge=0, le=255, description="The minimum output value") - max: int = Field(default=255, ge=0, le=255, description="The maximum output value") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Image Linear Interpolation", "tags": ["image", "linear", "interpolation", "lerp"]}, - } + image: ImageField = InputField(description="The image to lerp") + min: int = InputField(default=0, ge=0, le=255, description="The minimum output value") + max: int = InputField(default=255, ge=0, 
le=255, description="The maximum output value") def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -512,25 +451,18 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig): ) -class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig): +@title("Inverse Lerp Image") +@tags("image", "ilerp") +class ImageInverseLerpInvocation(BaseInvocation): """Inverse linear interpolation of all pixels of an image""" - # fmt: off + # Metadata type: Literal["img_ilerp"] = "img_ilerp" # Inputs - image: Optional[ImageField] = Field(default=None, description="The image to lerp") - min: int = Field(default=0, ge=0, le=255, description="The minimum input value") - max: int = Field(default=255, ge=0, le=255, description="The maximum input value") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "Image Inverse Linear Interpolation", - "tags": ["image", "linear", "interpolation", "inverse"], - }, - } + image: ImageField = InputField(description="The image to lerp") + min: int = InputField(default=0, ge=0, le=255, description="The minimum input value") + max: int = InputField(default=255, ge=0, le=255, description="The maximum input value") def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -556,21 +488,19 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig): ) -class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig): +@title("Blur NSFW Image") +@tags("image", "nsfw") +class ImageNSFWBlurInvocation(BaseInvocation): """Add blur to NSFW-flagged images""" - # fmt: off + # Metadata type: Literal["img_nsfw"] = "img_nsfw" # Inputs - image: Optional[ImageField] = Field(default=None, description="The image to check") - metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image") - # fmt: on - 
- class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Blur NSFW Images", "tags": ["image", "nsfw", "checker"]}, - } + image: ImageField = InputField(description="The image to check") + metadata: Optional[CoreMetadata] = InputField( + default=None, description=FieldDescriptions.core_metadata, ui_hidden=True + ) def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -607,22 +537,20 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig): return caution.resize((caution.width // 2, caution.height // 2)) -class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig): +@title("Add Invisible Watermark") +@tags("image", "watermark") +class ImageWatermarkInvocation(BaseInvocation): """Add an invisible watermark to an image""" - # fmt: off + # Metadata type: Literal["img_watermark"] = "img_watermark" # Inputs - image: Optional[ImageField] = Field(default=None, description="The image to check") - text: str = Field(default='InvokeAI', description="Watermark text") - metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Add Invisible Watermark", "tags": ["image", "watermark", "invisible"]}, - } + image: ImageField = InputField(description="The image to check") + text: str = InputField(default="InvokeAI", description="Watermark text") + metadata: Optional[CoreMetadata] = InputField( + default=None, description=FieldDescriptions.core_metadata, ui_hidden=True + ) def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -644,21 +572,23 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig): ) -class MaskEdgeInvocation(BaseInvocation, PILInvocationConfig): +@title("Mask Edge") +@tags("image", "mask", "inpaint") +class 
MaskEdgeInvocation(BaseInvocation): """Applies an edge mask to an image""" - # fmt: off type: Literal["mask_edge"] = "mask_edge" # Inputs - image: Optional[ImageField] = Field(default=None, description="The image to apply the mask to") - edge_size: int = Field(description="The size of the edge") - edge_blur: int = Field(description="The amount of blur on the edge") - low_threshold: int = Field(description="First threshold for the hysteresis procedure in Canny edge detection") - high_threshold: int = Field(description="Second threshold for the hysteresis procedure in Canny edge detection") - # fmt: on + image: ImageField = InputField(description="The image to apply the mask to") + edge_size: int = InputField(description="The size of the edge") + edge_blur: int = InputField(description="The amount of blur on the edge") + low_threshold: int = InputField(description="First threshold for the hysteresis procedure in Canny edge detection") + high_threshold: int = InputField( + description="Second threshold for the hysteresis procedure in Canny edge detection" + ) - def invoke(self, context: InvocationContext) -> MaskOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: mask = context.services.images.get_pil_image(self.image.image_name) npimg = numpy.asarray(mask, dtype=numpy.uint8) @@ -683,28 +613,23 @@ class MaskEdgeInvocation(BaseInvocation, PILInvocationConfig): is_intermediate=self.is_intermediate, ) - return MaskOutput( - mask=ImageField(image_name=image_dto.image_name), + return ImageOutput( + image=ImageField(image_name=image_dto.image_name), width=image_dto.width, height=image_dto.height, ) -class MaskCombineInvocation(BaseInvocation, PILInvocationConfig): +@title("Combine Mask") +@tags("image", "mask", "multiply") +class MaskCombineInvocation(BaseInvocation): """Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`.""" - # fmt: off type: Literal["mask_combine"] = "mask_combine" # Inputs - mask1: ImageField = 
Field(default=None, description="The first mask to combine") - mask2: ImageField = Field(default=None, description="The second image to combine") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Mask Combine", "tags": ["mask", "combine"]}, - } + mask1: ImageField = InputField(description="The first mask to combine") + mask2: ImageField = InputField(description="The second image to combine") def invoke(self, context: InvocationContext) -> ImageOutput: mask1 = context.services.images.get_pil_image(self.mask1.image_name).convert("L") @@ -728,7 +653,9 @@ class MaskCombineInvocation(BaseInvocation, PILInvocationConfig): ) -class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig): +@title("Color Correct") +@tags("image", "color") +class ColorCorrectInvocation(BaseInvocation): """ Shifts the colors of a target image to match the reference image, optionally using a mask to only color-correct certain regions of the target image. @@ -736,10 +663,11 @@ class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig): type: Literal["color_correct"] = "color_correct" - image: Optional[ImageField] = Field(default=None, description="The image to color-correct") - reference: Optional[ImageField] = Field(default=None, description="Reference image for color-correction") - mask: Optional[ImageField] = Field(default=None, description="Mask to use when applying color-correction") - mask_blur_radius: float = Field(default=8, description="Mask blur radius") + # Inputs + image: ImageField = InputField(description="The image to color-correct") + reference: ImageField = InputField(description="Reference image for color-correction") + mask: Optional[ImageField] = InputField(default=None, description="Mask to use when applying color-correction") + mask_blur_radius: float = InputField(default=8, description="Mask blur radius") def invoke(self, context: InvocationContext) -> ImageOutput: pil_init_mask = None @@ -833,16 +761,16 @@ class 
ColorCorrectInvocation(BaseInvocation, PILInvocationConfig): ) +@title("Image Hue Adjustment") +@tags("image", "hue", "hsl") class ImageHueAdjustmentInvocation(BaseInvocation): """Adjusts the Hue of an image.""" - # fmt: off type: Literal["img_hue_adjust"] = "img_hue_adjust" # Inputs - image: ImageField = Field(default=None, description="The image to adjust") - hue: int = Field(default=0, description="The degrees by which to rotate the hue, 0-360") - # fmt: on + image: ImageField = InputField(description="The image to adjust") + hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360") def invoke(self, context: InvocationContext) -> ImageOutput: pil_image = context.services.images.get_pil_image(self.image.image_name) @@ -877,16 +805,18 @@ class ImageHueAdjustmentInvocation(BaseInvocation): ) +@title("Image Luminosity Adjustment") +@tags("image", "luminosity", "hsl") class ImageLuminosityAdjustmentInvocation(BaseInvocation): """Adjusts the Luminosity (Value) of an image.""" - # fmt: off type: Literal["img_luminosity_adjust"] = "img_luminosity_adjust" # Inputs - image: ImageField = Field(default=None, description="The image to adjust") - luminosity: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)") - # fmt: on + image: ImageField = InputField(description="The image to adjust") + luminosity: float = InputField( + default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)" + ) def invoke(self, context: InvocationContext) -> ImageOutput: pil_image = context.services.images.get_pil_image(self.image.image_name) @@ -925,16 +855,16 @@ class ImageLuminosityAdjustmentInvocation(BaseInvocation): ) +@title("Image Saturation Adjustment") +@tags("image", "saturation", "hsl") class ImageSaturationAdjustmentInvocation(BaseInvocation): """Adjusts the Saturation of an image.""" - # fmt: off type: Literal["img_saturation_adjust"] = "img_saturation_adjust" # 
Inputs - image: ImageField = Field(default=None, description="The image to adjust") - saturation: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation") - # fmt: on + image: ImageField = InputField(description="The image to adjust") + saturation: float = InputField(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation") def invoke(self, context: InvocationContext) -> ImageOutput: pil_image = context.services.images.get_pil_image(self.image.image_name) diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index cd5b2f9a11..1547191f6c 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -5,18 +5,13 @@ from typing import Literal, Optional, get_args import numpy as np import math from PIL import Image, ImageOps -from pydantic import Field +from invokeai.app.invocations.primitives import ImageField, ImageOutput, ColorField -from invokeai.app.invocations.image import ImageOutput from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.backend.image_util.patchmatch import PatchMatch -from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin -from .baseinvocation import ( - BaseInvocation, - InvocationConfig, - InvocationContext, -) +from ..models.image import ImageCategory, ResourceOrigin +from .baseinvocation import BaseInvocation, InputField, InvocationContext, title, tags def infill_methods() -> list[str]: @@ -114,21 +109,20 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] return si +@title("Solid Color Infill") +@tags("image", "inpaint") class InfillColorInvocation(BaseInvocation): """Infills transparent areas of an image with a solid color""" type: Literal["infill_rgba"] = "infill_rgba" - image: Optional[ImageField] = Field(default=None, description="The image to infill") - color: ColorField = Field( + + # Inputs + image: ImageField = 
InputField(description="The image to infill") + color: ColorField = InputField( default=ColorField(r=127, g=127, b=127, a=255), description="The color to use to infill", ) - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Color Infill", "tags": ["image", "inpaint", "color", "infill"]}, - } - def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -153,25 +147,23 @@ class InfillColorInvocation(BaseInvocation): ) +@title("Tile Infill") +@tags("image", "inpaint") class InfillTileInvocation(BaseInvocation): """Infills transparent areas of an image with tiles of the image""" type: Literal["infill_tile"] = "infill_tile" - image: Optional[ImageField] = Field(default=None, description="The image to infill") - tile_size: int = Field(default=32, ge=1, description="The tile size (px)") - seed: int = Field( + # Input + image: ImageField = InputField(description="The image to infill") + tile_size: int = InputField(default=32, ge=1, description="The tile size (px)") + seed: int = InputField( ge=0, le=SEED_MAX, description="The seed to use for tile generation (omit for random)", default_factory=get_random_seed, ) - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Tile Infill", "tags": ["image", "inpaint", "tile", "infill"]}, - } - def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -194,17 +186,15 @@ class InfillTileInvocation(BaseInvocation): ) +@title("PatchMatch Infill") +@tags("image", "inpaint") class InfillPatchMatchInvocation(BaseInvocation): """Infills transparent areas of an image using the PatchMatch algorithm""" type: Literal["infill_patchmatch"] = "infill_patchmatch" - image: Optional[ImageField] = Field(default=None, description="The image to infill") - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Patch Match Infill", "tags": ["image", 
"inpaint", "patchmatch", "infill"]}, - } + # Inputs + image: ImageField = InputField(description="The image to infill") def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index c66c9c6214..40f7af8703 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -13,16 +13,25 @@ from diffusers.models.attention_processor import ( LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) -from diffusers.schedulers import DPMSolverSDEScheduler, SchedulerMixin as Scheduler +from diffusers.schedulers import DPMSolverSDEScheduler +from diffusers.schedulers import SchedulerMixin as Scheduler from pydantic import BaseModel, Field, validator from torchvision.transforms.functional import resize as tv_resize from invokeai.app.invocations.metadata import CoreMetadata +from invokeai.app.invocations.primitives import ( + ImageField, + ImageOutput, + LatentsField, + LatentsOutput, + build_latents_output, +) from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend.model_management.models import ModelType, SilenceWarnings from ...backend.model_management import BaseModelType, ModelPatcher +from ...backend.model_management.lora import ModelPatcher from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion.diffusers_pipeline import ( ConditioningData, @@ -32,48 +41,27 @@ from ...backend.stable_diffusion.diffusers_pipeline import ( ) from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP -from ...backend.util.devices import choose_precision, choose_torch_device, torch_dtype -from ..models.image import ImageCategory, ImageField, ResourceOrigin -from 
.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext +from ...backend.util.devices import choose_precision, choose_torch_device +from ..models.image import ImageCategory, ResourceOrigin +from .baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + FieldDescriptions, + Input, + InputField, + InvocationContext, + OutputField, + UIType, + tags, + title, +) from .compel import ConditioningField from .controlnet_image_processors import ControlField -from .image import ImageOutput from .model import ModelInfo, UNetField, VaeField DEFAULT_PRECISION = choose_precision(choose_torch_device()) -class LatentsField(BaseModel): - """A latents field used for passing latents between invocations""" - - latents_name: Optional[str] = Field(default=None, description="The name of the latents") - seed: Optional[int] = Field(description="Seed used to generate this latents") - - class Config: - schema_extra = {"required": ["latents_name"]} - - -class LatentsOutput(BaseInvocationOutput): - """Base class for invocations that output latents""" - - # fmt: off - type: Literal["latents_output"] = "latents_output" - - # Inputs - latents: LatentsField = Field(default=None, description="The output latents") - width: int = Field(description="The width of the latents in pixels") - height: int = Field(description="The height of the latents in pixels") - # fmt: on - - -def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int]): - return LatentsOutput( - latents=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, - ) - - SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))] @@ -111,30 +99,36 @@ def get_scheduler( return scheduler +@title("Denoise Latents") +@tags("latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l") class DenoiseLatentsInvocation(BaseInvocation): """Denoises noisy latents to decodable images""" type: 
Literal["denoise_latents"] = "denoise_latents" # Inputs - positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation") - negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation") - noise: Optional[LatentsField] = Field(description="The noise to use") - steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image") - cfg_scale: Union[float, List[float]] = Field( - default=7.5, - ge=1, - description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", + positive_conditioning: ConditioningField = InputField( + description=FieldDescriptions.positive_cond, input=Input.Connection ) - denoising_start: float = Field(default=0.0, ge=0, le=1, description="") - denoising_end: float = Field(default=1.0, ge=0, le=1, description="") - scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use") - unet: UNetField = Field(default=None, description="UNet submodel") - control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use") - latents: Optional[LatentsField] = Field(description="The latents to use as a base image") - mask: Optional[ImageField] = Field( - None, - description="Mask", + negative_conditioning: ConditioningField = InputField( + description=FieldDescriptions.negative_cond, input=Input.Connection + ) + noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection) + steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps) + cfg_scale: Union[float, List[float]] = InputField( + default=7.5, ge=1, description=FieldDescriptions.cfg_scale, ui_type=UIType.Float + ) + denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start) + denoising_end: float = InputField(default=1.0, ge=0, le=1, 
description=FieldDescriptions.denoising_end) + scheduler: SAMPLER_NAME_VALUES = InputField(default="euler", description=FieldDescriptions.scheduler) + unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection) + control: Union[ControlField, list[ControlField]] = InputField( + default=None, description=FieldDescriptions.control, input=Input.Connection + ) + latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection) + mask: Optional[ImageField] = InputField( + default=None, + description=FieldDescriptions.mask, ) @validator("cfg_scale") @@ -149,20 +143,6 @@ class DenoiseLatentsInvocation(BaseInvocation): raise ValueError("cfg_scale must be greater than 1") return v - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "Denoise Latents", - "tags": ["denoise", "latents"], - "type_hints": { - "model": "model", - "control": "control", - "cfg_scale": "number", - }, - }, - } - # TODO: pass this an emitter method or something? or a session for dispatching? 
def dispatch_progress( self, @@ -474,29 +454,29 @@ class DenoiseLatentsInvocation(BaseInvocation): return build_latents_output(latents_name=name, latents=result_latents, seed=seed) -# Latent to image +@title("Latents to Image") +@tags("latents", "image", "vae") class LatentsToImageInvocation(BaseInvocation): """Generates an image from latents.""" type: Literal["l2i"] = "l2i" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to generate an image from") - vae: VaeField = Field(default=None, description="Vae submodel") - tiled: bool = Field(default=False, description="Decode latents by overlaping tiles (less memory consumption)") - fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision") - metadata: Optional[CoreMetadata] = Field( - default=None, description="Optional core metadata to be written to the image" + latents: LatentsField = InputField( + description=FieldDescriptions.latents, + input=Input.Connection, + ) + vae: VaeField = InputField( + description=FieldDescriptions.vae, + input=Input.Connection, + ) + tiled: bool = InputField(default=False, description=FieldDescriptions.tiled) + fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32) + metadata: CoreMetadata = InputField( + default=None, + description=FieldDescriptions.core_metadata, + ui_hidden=True, ) - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "Latents To Image", - "tags": ["latents", "image"], - }, - } @torch.no_grad() def invoke(self, context: InvocationContext) -> ImageOutput: @@ -574,24 +554,30 @@ class LatentsToImageInvocation(BaseInvocation): LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"] +@title("Resize Latents") +@tags("latents", "resize") class ResizeLatentsInvocation(BaseInvocation): """Resizes latents to explicit width/height (in pixels). 
Provided dimensions are floor-divided by 8.""" type: Literal["lresize"] = "lresize" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to resize") - width: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The width to resize to (px)") - height: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The height to resize to (px)") - mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode") - antialias: bool = Field( - default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)" + latents: LatentsField = InputField( + description=FieldDescriptions.latents, + input=Input.Connection, ) - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Resize Latents", "tags": ["latents", "resize"]}, - } + width: int = InputField( + ge=64, + multiple_of=8, + description=FieldDescriptions.width, + ) + height: int = InputField( + ge=64, + multiple_of=8, + description=FieldDescriptions.width, + ) + mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode) + antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.services.latents.get(self.latents.latents_name) @@ -616,23 +602,21 @@ class ResizeLatentsInvocation(BaseInvocation): return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed) +@title("Scale Latents") +@tags("latents", "resize") class ScaleLatentsInvocation(BaseInvocation): """Scales latents by a given factor.""" type: Literal["lscale"] = "lscale" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to scale") - scale_factor: float = Field(gt=0, description="The factor by which to scale the latents") - mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The 
interpolation mode") - antialias: bool = Field( - default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)" + latents: LatentsField = InputField( + description=FieldDescriptions.latents, + input=Input.Connection, ) - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Scale Latents", "tags": ["latents", "scale"]}, - } + scale_factor: float = InputField(gt=0, description=FieldDescriptions.scale_factor) + mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode) + antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.services.latents.get(self.latents.latents_name) @@ -658,22 +642,23 @@ class ScaleLatentsInvocation(BaseInvocation): return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed) +@title("Image to Latents") +@tags("latents", "image", "vae") class ImageToLatentsInvocation(BaseInvocation): """Encodes an image into latents.""" type: Literal["i2l"] = "i2l" # Inputs - image: Optional[ImageField] = Field(description="The image to encode") - vae: VaeField = Field(default=None, description="Vae submodel") - tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)") - fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision") - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Image To Latents", "tags": ["latents", "image"]}, - } + image: ImageField = InputField( + description="The image to encode", + ) + vae: VaeField = InputField( + description=FieldDescriptions.vae, + input=Input.Connection, + ) + tiled: bool = InputField(default=False, description=FieldDescriptions.tiled) + fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", 
description=FieldDescriptions.fp32) @torch.no_grad() def invoke(self, context: InvocationContext) -> LatentsOutput: diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py index 32b1ab2a39..13e3d92f52 100644 --- a/invokeai/app/invocations/math.py +++ b/invokeai/app/invocations/math.py @@ -2,134 +2,83 @@ from typing import Literal -from pydantic import BaseModel, Field import numpy as np -from .baseinvocation import ( - BaseInvocation, - BaseInvocationOutput, - InvocationContext, - InvocationConfig, -) +from invokeai.app.invocations.primitives import IntegerOutput + +from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, tags, title -class MathInvocationConfig(BaseModel): - """Helper class to provide all math invocations with additional config""" - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "tags": ["math"], - } - } - - -class IntOutput(BaseInvocationOutput): - """An integer output""" - - # fmt: off - type: Literal["int_output"] = "int_output" - a: int = Field(default=None, description="The output integer") - # fmt: on - - -class FloatOutput(BaseInvocationOutput): - """A float output""" - - # fmt: off - type: Literal["float_output"] = "float_output" - param: float = Field(default=None, description="The output float") - # fmt: on - - -class AddInvocation(BaseInvocation, MathInvocationConfig): +@title("Add Integers") +@tags("math") +class AddInvocation(BaseInvocation): """Adds two numbers""" - # fmt: off type: Literal["add"] = "add" - a: int = Field(default=0, description="The first number") - b: int = Field(default=0, description="The second number") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Add", "tags": ["math", "add"]}, - } + # Inputs + a: int = InputField(default=0, description=FieldDescriptions.num_1) + b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context: 
InvocationContext) -> IntOutput: - return IntOutput(a=self.a + self.b) + def invoke(self, context: InvocationContext) -> IntegerOutput: + return IntegerOutput(a=self.a + self.b) -class SubtractInvocation(BaseInvocation, MathInvocationConfig): +@title("Subtract Integers") +@tags("math") +class SubtractInvocation(BaseInvocation): """Subtracts two numbers""" - # fmt: off type: Literal["sub"] = "sub" - a: int = Field(default=0, description="The first number") - b: int = Field(default=0, description="The second number") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Subtract", "tags": ["math", "subtract"]}, - } + # Inputs + a: int = InputField(default=0, description=FieldDescriptions.num_1) + b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context: InvocationContext) -> IntOutput: - return IntOutput(a=self.a - self.b) + def invoke(self, context: InvocationContext) -> IntegerOutput: + return IntegerOutput(a=self.a - self.b) -class MultiplyInvocation(BaseInvocation, MathInvocationConfig): +@title("Multiply Integers") +@tags("math") +class MultiplyInvocation(BaseInvocation): """Multiplies two numbers""" - # fmt: off type: Literal["mul"] = "mul" - a: int = Field(default=0, description="The first number") - b: int = Field(default=0, description="The second number") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Multiply", "tags": ["math", "multiply"]}, - } + # Inputs + a: int = InputField(default=0, description=FieldDescriptions.num_1) + b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context: InvocationContext) -> IntOutput: - return IntOutput(a=self.a * self.b) + def invoke(self, context: InvocationContext) -> IntegerOutput: + return IntegerOutput(a=self.a * self.b) -class DivideInvocation(BaseInvocation, MathInvocationConfig): +@title("Divide Integers") +@tags("math") +class DivideInvocation(BaseInvocation): """Divides two 
numbers""" - # fmt: off type: Literal["div"] = "div" - a: int = Field(default=0, description="The first number") - b: int = Field(default=0, description="The second number") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Divide", "tags": ["math", "divide"]}, - } + # Inputs + a: int = InputField(default=0, description=FieldDescriptions.num_1) + b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context: InvocationContext) -> IntOutput: - return IntOutput(a=int(self.a / self.b)) + def invoke(self, context: InvocationContext) -> IntegerOutput: + return IntegerOutput(a=int(self.a / self.b)) +@title("Random Integer") +@tags("math") class RandomIntInvocation(BaseInvocation): """Outputs a single random integer.""" - # fmt: off type: Literal["rand_int"] = "rand_int" - low: int = Field(default=0, description="The inclusive low value") - high: int = Field( - default=np.iinfo(np.int32).max, description="The exclusive high value" - ) - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Random Integer", "tags": ["math", "random", "integer"]}, - } + # Inputs + low: int = InputField(default=0, description="The inclusive low value") + high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value") - def invoke(self, context: InvocationContext) -> IntOutput: - return IntOutput(a=np.random.randint(self.low, self.high)) + def invoke(self, context: InvocationContext) -> IntegerOutput: + return IntegerOutput(a=np.random.randint(self.low, self.high)) diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py index d0549f8539..679c610750 100644 --- a/invokeai/app/invocations/metadata.py +++ b/invokeai/app/invocations/metadata.py @@ -1,18 +1,22 @@ -from typing import Literal, Optional, Union +from typing import Literal, Optional from pydantic import Field -from ...version import __version__ from 
invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationConfig, + InputField, InvocationContext, + OutputField, + tags, + title, ) from invokeai.app.invocations.controlnet_image_processors import ControlField from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField from invokeai.app.util.model_exclude_null import BaseModelExcludeNull +from ...version import __version__ + class LoRAMetadataField(BaseModelExcludeNull): """LoRA metadata for an image generated in InvokeAI.""" @@ -43,37 +47,37 @@ class CoreMetadata(BaseModelExcludeNull): model: MainModelField = Field(description="The main model used for inference") controlnets: list[ControlField] = Field(description="The ControlNets used for inference") loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference") - vae: Union[VAEModelField, None] = Field( + vae: Optional[VAEModelField] = Field( default=None, description="The VAE used for decoding, if the main model's default was not used", ) # Latents-to-Latents - strength: Union[float, None] = Field( + strength: Optional[float] = Field( default=None, description="The strength used for latents-to-latents", ) - init_image: Union[str, None] = Field(default=None, description="The name of the initial image") + init_image: Optional[str] = Field(default=None, description="The name of the initial image") # SDXL - positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter") - negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter") + positive_style_prompt: Optional[str] = Field(default=None, description="The positive style prompt parameter") + negative_style_prompt: Optional[str] = Field(default=None, description="The negative style prompt parameter") # SDXL Refiner - refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used") - 
refiner_cfg_scale: Union[float, None] = Field( + refiner_model: Optional[MainModelField] = Field(default=None, description="The SDXL Refiner model used") + refiner_cfg_scale: Optional[float] = Field( default=None, description="The classifier-free guidance scale parameter used for the refiner", ) - refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner") - refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner") - refiner_positive_aesthetic_store: Union[float, None] = Field( + refiner_steps: Optional[int] = Field(default=None, description="The number of steps used for the refiner") + refiner_scheduler: Optional[str] = Field(default=None, description="The scheduler used for the refiner") + refiner_positive_aesthetic_store: Optional[float] = Field( default=None, description="The aesthetic score used for the refiner" ) - refiner_negative_aesthetic_store: Union[float, None] = Field( + refiner_negative_aesthetic_store: Optional[float] = Field( default=None, description="The aesthetic score used for the refiner" ) - refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising") + refiner_start: Optional[float] = Field(default=None, description="The start value used for refiner denoising") class ImageMetadata(BaseModelExcludeNull): @@ -91,69 +95,86 @@ class MetadataAccumulatorOutput(BaseInvocationOutput): type: Literal["metadata_accumulator_output"] = "metadata_accumulator_output" - metadata: CoreMetadata = Field(description="The core metadata for the image") + metadata: CoreMetadata = OutputField(description="The core metadata for the image") +@title("Metadata Accumulator") +@tags("metadata") class MetadataAccumulatorInvocation(BaseInvocation): """Outputs a Core Metadata Object""" type: Literal["metadata_accumulator"] = "metadata_accumulator" - generation_mode: str = Field( + generation_mode: str = InputField( 
description="The generation mode that output this image", ) - positive_prompt: str = Field(description="The positive prompt parameter") - negative_prompt: str = Field(description="The negative prompt parameter") - width: int = Field(description="The width parameter") - height: int = Field(description="The height parameter") - seed: int = Field(description="The seed used for noise generation") - rand_device: str = Field(description="The device used for random number generation") - cfg_scale: float = Field(description="The classifier-free guidance scale parameter") - steps: int = Field(description="The number of steps used for inference") - scheduler: str = Field(description="The scheduler used for inference") - clip_skip: int = Field( + positive_prompt: str = InputField(description="The positive prompt parameter") + negative_prompt: str = InputField(description="The negative prompt parameter") + width: int = InputField(description="The width parameter") + height: int = InputField(description="The height parameter") + seed: int = InputField(description="The seed used for noise generation") + rand_device: str = InputField(description="The device used for random number generation") + cfg_scale: float = InputField(description="The classifier-free guidance scale parameter") + steps: int = InputField(description="The number of steps used for inference") + scheduler: str = InputField(description="The scheduler used for inference") + clip_skip: int = InputField( description="The number of skipped CLIP layers", ) - model: MainModelField = Field(description="The main model used for inference") - controlnets: list[ControlField] = Field(description="The ControlNets used for inference") - loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference") - strength: Union[float, None] = Field( + model: MainModelField = InputField(description="The main model used for inference") + controlnets: list[ControlField] = InputField(description="The ControlNets used for 
inference") + loras: list[LoRAMetadataField] = InputField(description="The LoRAs used for inference") + strength: Optional[float] = InputField( default=None, description="The strength used for latents-to-latents", ) - init_image: Union[str, None] = Field(default=None, description="The name of the initial image") - vae: Union[VAEModelField, None] = Field( + init_image: Optional[str] = InputField( + default=None, + description="The name of the initial image", + ) + vae: Optional[VAEModelField] = InputField( default=None, description="The VAE used for decoding, if the main model's default was not used", ) # SDXL - positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter") - negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter") + positive_style_prompt: Optional[str] = InputField( + default=None, + description="The positive style prompt parameter", + ) + negative_style_prompt: Optional[str] = InputField( + default=None, + description="The negative style prompt parameter", + ) # SDXL Refiner - refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used") - refiner_cfg_scale: Union[float, None] = Field( + refiner_model: Optional[MainModelField] = InputField( + default=None, + description="The SDXL Refiner model used", + ) + refiner_cfg_scale: Optional[float] = InputField( default=None, description="The classifier-free guidance scale parameter used for the refiner", ) - refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner") - refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner") - refiner_positive_aesthetic_score: Union[float, None] = Field( - default=None, description="The aesthetic score used for the refiner" + refiner_steps: Optional[int] = InputField( + default=None, + description="The number of steps used 
for the refiner", ) - refiner_negative_aesthetic_score: Union[float, None] = Field( - default=None, description="The aesthetic score used for the refiner" + refiner_scheduler: Optional[str] = InputField( + default=None, + description="The scheduler used for the refiner", + ) + refiner_positive_aesthetic_store: Optional[float] = InputField( + default=None, + description="The aesthetic score used for the refiner", + ) + refiner_negative_aesthetic_store: Optional[float] = InputField( + default=None, + description="The aesthetic score used for the refiner", + ) + refiner_start: Optional[float] = InputField( + default=None, + description="The start value used for refiner denoising", ) - refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising") - - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "Metadata Accumulator", - "tags": ["image", "metadata", "generation"], - }, - } def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput: """Collects and outputs a CoreMetadata object""" diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 0d21f8f0ce..484a4d71e1 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -4,7 +4,18 @@ from typing import List, Literal, Optional, Union from pydantic import BaseModel, Field from ...backend.model_management import BaseModelType, ModelType, SubModelType -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext +from .baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + FieldDescriptions, + InputField, + Input, + InvocationContext, + OutputField, + UIType, + tags, + title, +) class ModelInfo(BaseModel): @@ -39,13 +50,11 @@ class VaeField(BaseModel): class ModelLoaderOutput(BaseInvocationOutput): """Model loader output""" - # fmt: off type: Literal["model_loader_output"] = "model_loader_output" - unet: UNetField = 
Field(default=None, description="UNet submodel") - clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") - vae: VaeField = Field(default=None, description="Vae submodel") - # fmt: on + unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet") + clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP") + vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE") class MainModelField(BaseModel): @@ -63,24 +72,17 @@ class LoRAModelField(BaseModel): base_model: BaseModelType = Field(description="Base model") +@title("Main Model Loader") +@tags("model") class MainModelLoaderInvocation(BaseInvocation): """Loads a main model, outputting its submodels.""" type: Literal["main_model_loader"] = "main_model_loader" - model: MainModelField = Field(description="The model to load") + # Inputs + model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct) # TODO: precision? 
- # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "Model Loader", - "tags": ["model", "loader"], - "type_hints": {"model": "model"}, - }, - } - def invoke(self, context: InvocationContext) -> ModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name @@ -155,22 +157,6 @@ class MainModelLoaderInvocation(BaseInvocation): loras=[], skipped_layers=0, ), - clip2=ClipField( - tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Tokenizer2, - ), - text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.TextEncoder2, - ), - loras=[], - skipped_layers=0, - ), vae=VaeField( vae=ModelInfo( model_name=model_name, @@ -188,30 +174,27 @@ class LoraLoaderOutput(BaseInvocationOutput): # fmt: off type: Literal["lora_loader_output"] = "lora_loader_output" - unet: Optional[UNetField] = Field(default=None, description="UNet submodel") - clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels") + unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet") + clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") # fmt: on +@title("LoRA Loader") +@tags("lora", "model") class LoraLoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" type: Literal["lora_loader"] = "lora_loader" - lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name") - weight: float = Field(default=0.75, description="With what weight to apply lora") - - unet: Optional[UNetField] = Field(description="UNet model for applying lora") - clip: Optional[ClipField] = Field(description="Clip model for applying lora") - - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "Lora Loader", - "tags": ["lora", 
"loader"], - "type_hints": {"lora": "lora_model"}, - }, - } + # Inputs + lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA") + weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) + unet: Optional[UNetField] = InputField( + default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet" + ) + clip: Optional[ClipField] = InputField( + default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP" + ) def invoke(self, context: InvocationContext) -> LoraLoaderOutput: if self.lora is None: @@ -263,37 +246,35 @@ class LoraLoaderInvocation(BaseInvocation): class SDXLLoraLoaderOutput(BaseInvocationOutput): - """Model loader output""" + """SDXL LoRA Loader Output""" # fmt: off type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output" - unet: Optional[UNetField] = Field(default=None, description="UNet submodel") - clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels") - clip2: Optional[ClipField] = Field(default=None, description="Tokenizer2 and text_encoder2 submodels") + unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet") + clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1") + clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2") # fmt: on +@title("SDXL LoRA Loader") +@tags("sdxl", "lora", "model") class SDXLLoraLoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader" - lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name") - weight: float = Field(default=0.75, description="With what weight to apply lora") - - unet: Optional[UNetField] = Field(description="UNet model for applying lora") - clip: Optional[ClipField] 
= Field(description="Clip model for applying lora") - clip2: Optional[ClipField] = Field(description="Clip2 model for applying lora") - - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "SDXL Lora Loader", - "tags": ["lora", "loader"], - "type_hints": {"lora": "lora_model"}, - }, - } + lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA") + weight: float = Field(default=0.75, description=FieldDescriptions.lora_weight) + unet: Optional[UNetField] = Field( + default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNET" + ) + clip: Optional[ClipField] = Field( + default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1" + ) + clip2: Optional[ClipField] = Field( + default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2" + ) def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: if self.lora is None: @@ -369,29 +350,23 @@ class VAEModelField(BaseModel): class VaeLoaderOutput(BaseInvocationOutput): """Model loader output""" - # fmt: off type: Literal["vae_loader_output"] = "vae_loader_output" - vae: VaeField = Field(default=None, description="Vae model") - # fmt: on + # Outputs + vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE") +@title("VAE Loader") +@tags("vae", "model") class VaeLoaderInvocation(BaseInvocation): """Loads a VAE model, outputting a VaeLoaderOutput""" type: Literal["vae_loader"] = "vae_loader" - vae_model: VAEModelField = Field(description="The VAE to load") - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "VAE Loader", - "tags": ["vae", "loader"], - "type_hints": {"vae_model": "vae_model"}, - }, - } + # Inputs + vae_model: VAEModelField = InputField( + description=FieldDescriptions.vae_model, input=Input.Direct, ui_type=UIType.VaeModel, title="VAE" + ) def invoke(self, context: InvocationContext) 
-> VaeLoaderOutput: base_model = self.vae_model.base_model diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index db64e5b6e5..6fae308fdc 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -1,19 +1,24 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team -import math from typing import Literal -from pydantic import Field, validator import torch -from invokeai.app.invocations.latent import LatentsField +from pydantic import validator +from invokeai.app.invocations.latent import LatentsField from invokeai.app.util.misc import SEED_MAX, get_random_seed + from ...backend.util.devices import choose_torch_device, torch_dtype from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationConfig, + FieldDescriptions, + InputField, InvocationContext, + OutputField, + UIType, + tags, + title, ) """ @@ -61,14 +66,12 @@ Nodes class NoiseOutput(BaseInvocationOutput): """Invocation noise output""" - # fmt: off - type: Literal["noise_output"] = "noise_output" + type: Literal["noise_output"] = "noise_output" # Inputs - noise: LatentsField = Field(default=None, description="The output noise") - width: int = Field(description="The width of the noise in pixels") - height: int = Field(description="The height of the noise in pixels") - # fmt: on + noise: LatentsField = OutputField(default=None, description=FieldDescriptions.noise) + width: int = OutputField(description=FieldDescriptions.width) + height: int = OutputField(description=FieldDescriptions.height) def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int): @@ -79,44 +82,37 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int): ) +@title("Noise") +@tags("latents", "noise") class NoiseInvocation(BaseInvocation): """Generates latent noise.""" type: Literal["noise"] = "noise" # Inputs - seed: int = Field( + seed: int = InputField( ge=0, le=SEED_MAX, - 
description="The seed to use", + description=FieldDescriptions.seed, default_factory=get_random_seed, ) - width: int = Field( + width: int = InputField( default=512, multiple_of=8, gt=0, - description="The width of the resulting noise", + description=FieldDescriptions.width, ) - height: int = Field( + height: int = InputField( default=512, multiple_of=8, gt=0, - description="The height of the resulting noise", + description=FieldDescriptions.height, ) - use_cpu: bool = Field( + use_cpu: bool = InputField( default=True, description="Use CPU for noise generation (for reproducible results across platforms)", ) - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "Noise", - "tags": ["latents", "noise"], - }, - } - @validator("seed", pre=True) def modulo_seed(cls, v): """Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range.""" diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index 4f04a4f023..6cff56fb2e 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -1,37 +1,43 @@ # Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779) +import inspect +import re from contextlib import ExitStack from typing import List, Literal, Optional, Union -import re -import inspect - -from pydantic import BaseModel, Field, validator -import torch import numpy as np +import torch from diffusers import ControlNetModel, DPMSolverMultistepScheduler from diffusers.image_processor import VaeImageProcessor from diffusers.schedulers import SchedulerMixin as Scheduler - -from ..models.image import ImageCategory, ImageField, ResourceOrigin -from ...backend.model_management import ONNXModelPatcher -from ...backend.util import choose_torch_device -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext -from .compel import ConditioningField -from .controlnet_image_processors import ControlField -from .image 
import ImageOutput -from .model import ModelInfo, UNetField, VaeField +from pydantic import BaseModel, Field, validator +from tqdm import tqdm from invokeai.app.invocations.metadata import CoreMetadata -from invokeai.backend import BaseModelType, ModelType, SubModelType +from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput from invokeai.app.util.step_callback import stable_diffusion_step_callback +from invokeai.backend import BaseModelType, ModelType, SubModelType + +from ...backend.model_management import ONNXModelPatcher from ...backend.stable_diffusion import PipelineIntermediateState - -from tqdm import tqdm -from .model import ClipField -from .latent import LatentsField, LatentsOutput, build_latents_output, get_scheduler, SAMPLER_NAME_VALUES -from .compel import CompelOutput - +from ...backend.util import choose_torch_device +from ..models.image import ImageCategory, ResourceOrigin +from .baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + FieldDescriptions, + InputField, + Input, + InvocationContext, + OutputField, + UIComponent, + UIType, + tags, + title, +) +from .controlnet_image_processors import ControlField +from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler +from .model import ClipField, ModelInfo, UNetField, VaeField ORT_TO_NP_TYPE = { "tensor(bool)": np.bool_, @@ -51,13 +57,15 @@ ORT_TO_NP_TYPE = { PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))] +@title("ONNX Prompt (Raw)") +@tags("onnx", "prompt") class ONNXPromptInvocation(BaseInvocation): type: Literal["prompt_onnx"] = "prompt_onnx" - prompt: str = Field(default="", description="Prompt") - clip: ClipField = Field(None, description="Clip to use") + prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea) + clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) - def 
invoke(self, context: InvocationContext) -> CompelOutput: + def invoke(self, context: InvocationContext) -> ConditioningOutput: tokenizer_info = context.services.model_manager.get_model( **self.clip.tokenizer.dict(), ) @@ -126,7 +134,7 @@ class ONNXPromptInvocation(BaseInvocation): # TODO: hacky but works ;D maybe rename latents somehow? context.services.latents.save(conditioning_name, (prompt_embeds, None)) - return CompelOutput( + return ConditioningOutput( conditioning=ConditioningField( conditioning_name=conditioning_name, ), @@ -134,25 +142,48 @@ class ONNXPromptInvocation(BaseInvocation): # Text to image +@title("ONNX Text to Latents") +@tags("latents", "inference", "txt2img", "onnx") class ONNXTextToLatentsInvocation(BaseInvocation): """Generates latents from conditionings.""" type: Literal["t2l_onnx"] = "t2l_onnx" # Inputs - # fmt: off - positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation") - negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation") - noise: Optional[LatentsField] = Field(description="The noise to use") - steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image") - cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) - scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" ) - precision: PRECISION_VALUES = Field(default = "tensor(float16)", description="The precision to use when generating latents") - unet: UNetField = Field(default=None, description="UNet submodel") - control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use") - # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) - # seamless_axes: str = Field(default="", 
description="The axes to tile the image on, 'x' and/or 'y'") - # fmt: on + positive_conditioning: ConditioningField = InputField( + description=FieldDescriptions.positive_cond, + input=Input.Connection, + ) + negative_conditioning: ConditioningField = InputField( + description=FieldDescriptions.negative_cond, + input=Input.Connection, + ) + noise: LatentsField = InputField( + description=FieldDescriptions.noise, + input=Input.Connection, + ) + steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps) + cfg_scale: Union[float, List[float]] = InputField( + default=7.5, + ge=1, + description=FieldDescriptions.cfg_scale, + ui_type=UIType.Float, + ) + scheduler: SAMPLER_NAME_VALUES = InputField( + default="euler", description=FieldDescriptions.scheduler, input=Input.Direct + ) + precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision) + unet: UNetField = InputField( + description=FieldDescriptions.unet, + input=Input.Connection, + ) + control: Optional[Union[ControlField, list[ControlField]]] = InputField( + default=None, + description=FieldDescriptions.control, + ui_type=UIType.Control, + ) + # seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", ) + # seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'") @validator("cfg_scale") def ge_one(cls, v): @@ -166,20 +197,6 @@ class ONNXTextToLatentsInvocation(BaseInvocation): raise ValueError("cfg_scale must be greater than 1") return v - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "tags": ["latents"], - "type_hints": { - "model": "model", - "control": "control", - # "cfg_scale": "float", - "cfg_scale": "number", - }, - }, - } - # based on # 
https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375 def invoke(self, context: InvocationContext) -> LatentsOutput: @@ -300,26 +317,28 @@ class ONNXTextToLatentsInvocation(BaseInvocation): # Latent to image +@title("ONNX Latents to Image") +@tags("latents", "image", "vae", "onnx") class ONNXLatentsToImageInvocation(BaseInvocation): """Generates an image from latents.""" type: Literal["l2i_onnx"] = "l2i_onnx" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to generate an image from") - vae: VaeField = Field(default=None, description="Vae submodel") - metadata: Optional[CoreMetadata] = Field( - default=None, description="Optional core metadata to be written to the image" + latents: LatentsField = InputField( + description=FieldDescriptions.denoised_latents, + input=Input.Connection, ) - # tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)") - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "tags": ["latents", "image"], - }, - } + vae: VaeField = InputField( + description=FieldDescriptions.vae, + input=Input.Connection, + ) + metadata: Optional[CoreMetadata] = InputField( + default=None, + description=FieldDescriptions.core_metadata, + ui_hidden=True, + ) + # tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)") def invoke(self, context: InvocationContext) -> ImageOutput: latents = context.services.latents.get(self.latents.latents_name) @@ -373,89 +392,13 @@ class ONNXModelLoaderOutput(BaseInvocationOutput): # fmt: off type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx" - unet: UNetField = Field(default=None, description="UNet submodel") - clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") - vae_decoder: VaeField = 
Field(default=None, description="Vae submodel") - vae_encoder: VaeField = Field(default=None, description="Vae submodel") + unet: UNetField = OutputField(default=None, description=FieldDescriptions.unet, title="UNet") + clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") + vae_decoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Decoder") + vae_encoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Encoder") # fmt: on -class ONNXSD1ModelLoaderInvocation(BaseInvocation): - """Loading submodels of selected model.""" - - type: Literal["sd1_model_loader_onnx"] = "sd1_model_loader_onnx" - - model_name: str = Field(default="", description="Model to load") - # TODO: precision? - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": {"tags": ["model", "loader"], "type_hints": {"model_name": "model"}}, # TODO: rename to model_name? - } - - def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput: - model_name = "stable-diffusion-v1-5" - base_model = BaseModelType.StableDiffusion1 - - # TODO: not found exceptions - if not context.services.model_manager.model_exists( - model_name=model_name, - base_model=BaseModelType.StableDiffusion1, - model_type=ModelType.ONNX, - ): - raise Exception(f"Unkown model name: {model_name}!") - - return ONNXModelLoaderOutput( - unet=UNetField( - unet=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=ModelType.ONNX, - submodel=SubModelType.UNet, - ), - scheduler=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=ModelType.ONNX, - submodel=SubModelType.Scheduler, - ), - loras=[], - ), - clip=ClipField( - tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=ModelType.ONNX, - submodel=SubModelType.Tokenizer, - ), - text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - 
model_type=ModelType.ONNX, - submodel=SubModelType.TextEncoder, - ), - loras=[], - ), - vae_decoder=VaeField( - vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=ModelType.ONNX, - submodel=SubModelType.VaeDecoder, - ), - ), - vae_encoder=VaeField( - vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=ModelType.ONNX, - submodel=SubModelType.VaeEncoder, - ), - ), - ) - - class OnnxModelField(BaseModel): """Onnx model field""" @@ -464,22 +407,17 @@ class OnnxModelField(BaseModel): model_type: ModelType = Field(description="Model Type") +@title("ONNX Model Loader") +@tags("onnx", "model") class OnnxModelLoaderInvocation(BaseInvocation): """Loads a main model, outputting its submodels.""" type: Literal["onnx_model_loader"] = "onnx_model_loader" - model: OnnxModelField = Field(description="The model to load") - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "Onnx Model Loader", - "tags": ["model", "loader"], - "type_hints": {"model": "model"}, - }, - } + # Inputs + model: OnnxModelField = InputField( + description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel + ) def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput: base_model = self.model.base_model diff --git a/invokeai/app/invocations/param_easing.py b/invokeai/app/invocations/param_easing.py index f910e5379c..67f96b0c34 100644 --- a/invokeai/app/invocations/param_easing.py +++ b/invokeai/app/invocations/param_easing.py @@ -1,73 +1,64 @@ import io -from typing import Literal, Optional, Any +from typing import Literal, Optional -# from PIL.Image import Image -import PIL.Image -from matplotlib.ticker import MaxNLocator -from matplotlib.figure import Figure - -from pydantic import BaseModel, Field -import numpy as np import matplotlib.pyplot as plt +import numpy as np +import PIL.Image from easing_functions import ( - LinearInOut, - QuadEaseInOut, - QuadEaseIn, - 
QuadEaseOut, - CubicEaseInOut, - CubicEaseIn, - CubicEaseOut, - QuarticEaseInOut, - QuarticEaseIn, - QuarticEaseOut, - QuinticEaseInOut, - QuinticEaseIn, - QuinticEaseOut, - SineEaseInOut, - SineEaseIn, - SineEaseOut, - CircularEaseIn, - CircularEaseInOut, - CircularEaseOut, - ExponentialEaseInOut, - ExponentialEaseIn, - ExponentialEaseOut, - ElasticEaseIn, - ElasticEaseInOut, - ElasticEaseOut, BackEaseIn, BackEaseInOut, BackEaseOut, BounceEaseIn, BounceEaseInOut, BounceEaseOut, + CircularEaseIn, + CircularEaseInOut, + CircularEaseOut, + CubicEaseIn, + CubicEaseInOut, + CubicEaseOut, + ElasticEaseIn, + ElasticEaseInOut, + ElasticEaseOut, + ExponentialEaseIn, + ExponentialEaseInOut, + ExponentialEaseOut, + LinearInOut, + QuadEaseIn, + QuadEaseInOut, + QuadEaseOut, + QuarticEaseIn, + QuarticEaseInOut, + QuarticEaseOut, + QuinticEaseIn, + QuinticEaseInOut, + QuinticEaseOut, + SineEaseIn, + SineEaseInOut, + SineEaseOut, ) +from matplotlib.figure import Figure +from matplotlib.ticker import MaxNLocator +from pydantic import BaseModel, Field + +from invokeai.app.invocations.primitives import FloatCollectionOutput -from .baseinvocation import ( - BaseInvocation, - BaseInvocationOutput, - InvocationContext, - InvocationConfig, -) from ...backend.util.logging import InvokeAILogger -from .collections import FloatCollectionOutput +from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title +@title("Float Range") +@tags("math", "range") class FloatLinearRangeInvocation(BaseInvocation): """Creates a range""" type: Literal["float_range"] = "float_range" # Inputs - start: float = Field(default=5, description="The first value of the range") - stop: float = Field(default=10, description="The last value of the range") - steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)") - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Linear Range (Float)", "tags": ["math", "float", 
"linear", "range"]}, - } + start: float = InputField(default=5, description="The first value of the range") + stop: float = InputField(default=10, description="The last value of the range") + steps: int = InputField(default=30, description="number of values to interpolate over (including start and stop)") def invoke(self, context: InvocationContext) -> FloatCollectionOutput: param_list = list(np.linspace(self.start, self.stop, self.steps)) @@ -108,37 +99,32 @@ EASING_FUNCTIONS_MAP = { "BounceInOut": BounceEaseInOut, } -EASING_FUNCTION_KEYS: Any = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))] +EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))] # actually I think for now could just use CollectionOutput (which is list[Any] +@title("Step Param Easing") +@tags("step", "easing") class StepParamEasingInvocation(BaseInvocation): """Experimental per-step parameter easing for denoising steps""" type: Literal["step_param_easing"] = "step_param_easing" # Inputs - # fmt: off - easing: EASING_FUNCTION_KEYS = Field(default="Linear", description="The easing function to use") - num_steps: int = Field(default=20, description="number of denoising steps") - start_value: float = Field(default=0.0, description="easing starting value") - end_value: float = Field(default=1.0, description="easing ending value") - start_step_percent: float = Field(default=0.0, description="fraction of steps at which to start easing") - end_step_percent: float = Field(default=1.0, description="fraction of steps after which to end easing") + easing: EASING_FUNCTION_KEYS = InputField(default="Linear", description="The easing function to use") + num_steps: int = InputField(default=20, description="number of denoising steps") + start_value: float = InputField(default=0.0, description="easing starting value") + end_value: float = InputField(default=1.0, description="easing ending value") + start_step_percent: float = InputField(default=0.0, description="fraction of steps at which to 
start easing") + end_step_percent: float = InputField(default=1.0, description="fraction of steps after which to end easing") # if None, then start_value is used prior to easing start - pre_start_value: Optional[float] = Field(default=None, description="value before easing start") + pre_start_value: Optional[float] = InputField(default=None, description="value before easing start") # if None, then end value is used prior to easing end - post_end_value: Optional[float] = Field(default=None, description="value after easing end") - mirror: bool = Field(default=False, description="include mirror of easing function") + post_end_value: Optional[float] = InputField(default=None, description="value after easing end") + mirror: bool = InputField(default=False, description="include mirror of easing function") # FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely - # alt_mirror: bool = Field(default=False, description="alternative mirroring by dual easing") - show_easing_plot: bool = Field(default=False, description="show easing plot") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Param Easing By Step", "tags": ["param", "step", "easing"]}, - } + # alt_mirror: bool = InputField(default=False, description="alternative mirroring by dual easing") + show_easing_plot: bool = InputField(default=False, description="show easing plot") def invoke(self, context: InvocationContext) -> FloatCollectionOutput: log_diagnostics = False diff --git a/invokeai/app/invocations/params.py b/invokeai/app/invocations/params.py deleted file mode 100644 index 513eb8762f..0000000000 --- a/invokeai/app/invocations/params.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) - -from typing import Literal - -from pydantic import Field - -from invokeai.app.invocations.prompt import PromptOutput - -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, 
InvocationContext -from .math import FloatOutput, IntOutput - -# Pass-through parameter nodes - used by subgraphs - - -class ParamIntInvocation(BaseInvocation): - """An integer parameter""" - - # fmt: off - type: Literal["param_int"] = "param_int" - a: int = Field(default=0, description="The integer value") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"tags": ["param", "integer"], "title": "Integer Parameter"}, - } - - def invoke(self, context: InvocationContext) -> IntOutput: - return IntOutput(a=self.a) - - -class ParamFloatInvocation(BaseInvocation): - """A float parameter""" - - # fmt: off - type: Literal["param_float"] = "param_float" - param: float = Field(default=0.0, description="The float value") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"tags": ["param", "float"], "title": "Float Parameter"}, - } - - def invoke(self, context: InvocationContext) -> FloatOutput: - return FloatOutput(param=self.param) - - -class StringOutput(BaseInvocationOutput): - """A string output""" - - type: Literal["string_output"] = "string_output" - text: str = Field(default=None, description="The output string") - - -class ParamStringInvocation(BaseInvocation): - """A string parameter""" - - type: Literal["param_string"] = "param_string" - text: str = Field(default="", description="The string value") - - class Config(InvocationConfig): - schema_extra = { - "ui": {"tags": ["param", "string"], "title": "String Parameter"}, - } - - def invoke(self, context: InvocationContext) -> StringOutput: - return StringOutput(text=self.text) - - -class ParamPromptInvocation(BaseInvocation): - """A prompt input parameter""" - - type: Literal["param_prompt"] = "param_prompt" - prompt: str = Field(default="", description="The prompt value") - - class Config(InvocationConfig): - schema_extra = { - "ui": {"tags": ["param", "prompt"], "title": "Prompt"}, - } - - def invoke(self, context: InvocationContext) -> PromptOutput: - return 
PromptOutput(prompt=self.prompt) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py new file mode 100644 index 0000000000..398be04738 --- /dev/null +++ b/invokeai/app/invocations/primitives.py @@ -0,0 +1,494 @@ +# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) + +from typing import Literal, Optional, Tuple, Union +from anyio import Condition + +from pydantic import BaseModel, Field +import torch + +from .baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + FieldDescriptions, + Input, + InputField, + InvocationContext, + OutputField, + UIComponent, + UIType, + tags, + title, +) + +""" +Primitives: Boolean, Integer, Float, String, Image, Latents, Conditioning, Color +- primitive nodes +- primitive outputs +- primitive collection outputs +""" + +# region Boolean + + +class BooleanOutput(BaseInvocationOutput): + """Base class for nodes that output a single boolean""" + + type: Literal["boolean_output"] = "boolean_output" + a: bool = OutputField(description="The output boolean") + + +class BooleanCollectionOutput(BaseInvocationOutput): + """Base class for nodes that output a collection of booleans""" + + type: Literal["boolean_collection_output"] = "boolean_collection_output" + + # Outputs + collection: list[bool] = OutputField( + default_factory=list, description="The output boolean collection", ui_type=UIType.BooleanCollection + ) + + +@title("Boolean") +@tags("primitives", "boolean") +class BooleanInvocation(BaseInvocation): + """A boolean primitive value""" + + type: Literal["boolean"] = "boolean" + + # Inputs + a: bool = InputField(default=False, description="The boolean value") + + def invoke(self, context: InvocationContext) -> BooleanOutput: + return BooleanOutput(a=self.a) + + +@title("Boolean Collection") +@tags("primitives", "boolean", "collection") +class BooleanCollectionInvocation(BaseInvocation): + """A collection of boolean primitive values""" + + type: 
Literal["boolean_collection"] = "boolean_collection" + + # Inputs + collection: list[bool] = InputField( + default=False, description="The collection of boolean values", ui_type=UIType.BooleanCollection + ) + + def invoke(self, context: InvocationContext) -> BooleanCollectionOutput: + return BooleanCollectionOutput(collection=self.collection) + + +# endregion + +# region Integer + + +class IntegerOutput(BaseInvocationOutput): + """Base class for nodes that output a single integer""" + + type: Literal["integer_output"] = "integer_output" + a: int = OutputField(description="The output integer") + + +class IntegerCollectionOutput(BaseInvocationOutput): + """Base class for nodes that output a collection of integers""" + + type: Literal["integer_collection_output"] = "integer_collection_output" + + # Outputs + collection: list[int] = OutputField( + default_factory=list, description="The int collection", ui_type=UIType.IntegerCollection + ) + + +@title("Integer") +@tags("primitives", "integer") +class IntegerInvocation(BaseInvocation): + """An integer primitive value""" + + type: Literal["integer"] = "integer" + + # Inputs + a: int = InputField(default=0, description="The integer value") + + def invoke(self, context: InvocationContext) -> IntegerOutput: + return IntegerOutput(a=self.a) + + +@title("Integer Collection") +@tags("primitives", "integer", "collection") +class IntegerCollectionInvocation(BaseInvocation): + """A collection of integer primitive values""" + + type: Literal["integer_collection"] = "integer_collection" + + # Inputs + collection: list[int] = InputField( + default=0, description="The collection of integer values", ui_type=UIType.IntegerCollection + ) + + def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: + return IntegerCollectionOutput(collection=self.collection) + + +# endregion + +# region Float + + +class FloatOutput(BaseInvocationOutput): + """Base class for nodes that output a single float""" + + type: 
Literal["float_output"] = "float_output" + a: float = OutputField(description="The output float") + + +class FloatCollectionOutput(BaseInvocationOutput): + """Base class for nodes that output a collection of floats""" + + type: Literal["float_collection_output"] = "float_collection_output" + + # Outputs + collection: list[float] = OutputField( + default_factory=list, description="The float collection", ui_type=UIType.FloatCollection + ) + + +@title("Float") +@tags("primitives", "float") +class FloatInvocation(BaseInvocation): + """A float primitive value""" + + type: Literal["float"] = "float" + + # Inputs + param: float = InputField(default=0.0, description="The float value") + + def invoke(self, context: InvocationContext) -> FloatOutput: + return FloatOutput(a=self.param) + + +@title("Float Collection") +@tags("primitives", "float", "collection") +class FloatCollectionInvocation(BaseInvocation): + """A collection of float primitive values""" + + type: Literal["float_collection"] = "float_collection" + + # Inputs + collection: list[float] = InputField( + default=0, description="The collection of float values", ui_type=UIType.FloatCollection + ) + + def invoke(self, context: InvocationContext) -> FloatCollectionOutput: + return FloatCollectionOutput(collection=self.collection) + + +# endregion + +# region String + + +class StringOutput(BaseInvocationOutput): + """Base class for nodes that output a single string""" + + type: Literal["string_output"] = "string_output" + text: str = OutputField(description="The output string") + + +class StringCollectionOutput(BaseInvocationOutput): + """Base class for nodes that output a collection of strings""" + + type: Literal["string_collection_output"] = "string_collection_output" + + # Outputs + collection: list[str] = OutputField( + default_factory=list, description="The output strings", ui_type=UIType.StringCollection + ) + + +@title("String") +@tags("primitives", "string") +class StringInvocation(BaseInvocation): + """A 
string primitive value""" + + type: Literal["string"] = "string" + + # Inputs + text: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea) + + def invoke(self, context: InvocationContext) -> StringOutput: + return StringOutput(text=self.text) + + +@title("String Collection") +@tags("primitives", "string", "collection") +class StringCollectionInvocation(BaseInvocation): + """A collection of string primitive values""" + + type: Literal["string_collection"] = "string_collection" + + # Inputs + collection: list[str] = InputField( + default=0, description="The collection of string values", ui_type=UIType.StringCollection + ) + + def invoke(self, context: InvocationContext) -> StringCollectionOutput: + return StringCollectionOutput(collection=self.collection) + + +# endregion + +# region Image + + +class ImageField(BaseModel): + """An image primitive field""" + + image_name: str = Field(description="The name of the image") + + +class ImageOutput(BaseInvocationOutput): + """Base class for nodes that output a single image""" + + type: Literal["image_output"] = "image_output" + image: ImageField = OutputField(description="The output image") + width: int = OutputField(description="The width of the image in pixels") + height: int = OutputField(description="The height of the image in pixels") + + +class ImageCollectionOutput(BaseInvocationOutput): + """Base class for nodes that output a collection of images""" + + type: Literal["image_collection_output"] = "image_collection_output" + + # Outputs + collection: list[ImageField] = OutputField( + default_factory=list, description="The output images", ui_type=UIType.ImageCollection + ) + + +@title("Image Primitive") +@tags("primitives", "image") +class ImageInvocation(BaseInvocation): + """An image primitive value""" + + # Metadata + type: Literal["image"] = "image" + + # Inputs + image: ImageField = InputField(description="The image to load") + + def invoke(self, context: 
InvocationContext) -> ImageOutput: + image = context.services.images.get_pil_image(self.image.image_name) + + return ImageOutput( + image=ImageField(image_name=self.image.image_name), + width=image.width, + height=image.height, + ) + + +@title("Image Collection") +@tags("primitives", "image", "collection") +class ImageCollectionInvocation(BaseInvocation): + """A collection of image primitive values""" + + type: Literal["image_collection"] = "image_collection" + + # Inputs + collection: list[ImageField] = InputField( + default=0, description="The collection of image values", ui_type=UIType.ImageCollection + ) + + def invoke(self, context: InvocationContext) -> ImageCollectionOutput: + return ImageCollectionOutput(collection=self.collection) + + +# endregion + +# region Latents + + +class LatentsField(BaseModel): + """A latents tensor primitive field""" + + latents_name: str = Field(description="The name of the latents") + seed: Optional[int] = Field(default=None, description="Seed used to generate this latents") + + +class LatentsOutput(BaseInvocationOutput): + """Base class for nodes that output a single latents tensor""" + + type: Literal["latents_output"] = "latents_output" + + latents: LatentsField = OutputField( + description=FieldDescriptions.latents, + ) + width: int = OutputField(description=FieldDescriptions.width) + height: int = OutputField(description=FieldDescriptions.height) + + +class LatentsCollectionOutput(BaseInvocationOutput): + """Base class for nodes that output a collection of latents tensors""" + + type: Literal["latents_collection_output"] = "latents_collection_output" + + collection: list[LatentsField] = OutputField( + default_factory=list, + description=FieldDescriptions.latents, + ui_type=UIType.LatentsCollection, + ) + + +@title("Latents Primitive") +@tags("primitives", "latents") +class LatentsInvocation(BaseInvocation): + """A latents tensor primitive value""" + + type: Literal["latents"] = "latents" + + # Inputs + latents: LatentsField 
= InputField(description="The latents tensor", input=Input.Connection) + + def invoke(self, context: InvocationContext) -> LatentsOutput: + latents = context.services.latents.get(self.latents.latents_name) + + return build_latents_output(self.latents.latents_name, latents) + + +@title("Latents Collection") +@tags("primitives", "latents", "collection") +class LatentsCollectionInvocation(BaseInvocation): + """A collection of latents tensor primitive values""" + + type: Literal["latents_collection"] = "latents_collection" + + # Inputs + collection: list[LatentsField] = InputField( + default=0, description="The collection of latents tensors", ui_type=UIType.LatentsCollection + ) + + def invoke(self, context: InvocationContext) -> LatentsCollectionOutput: + return LatentsCollectionOutput(collection=self.collection) + + +def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None): + return LatentsOutput( + latents=LatentsField(latents_name=latents_name, seed=seed), + width=latents.size()[3] * 8, + height=latents.size()[2] * 8, + ) + + +# endregion + +# region Color + + +class ColorField(BaseModel): + """A color primitive field""" + + r: int = Field(ge=0, le=255, description="The red component") + g: int = Field(ge=0, le=255, description="The green component") + b: int = Field(ge=0, le=255, description="The blue component") + a: int = Field(ge=0, le=255, description="The alpha component") + + def tuple(self) -> Tuple[int, int, int, int]: + return (self.r, self.g, self.b, self.a) + + +class ColorOutput(BaseInvocationOutput): + """Base class for nodes that output a single color""" + + type: Literal["color_output"] = "color_output" + color: ColorField = OutputField(description="The output color") + + +class ColorCollectionOutput(BaseInvocationOutput): + """Base class for nodes that output a collection of colors""" + + type: Literal["color_collection_output"] = "color_collection_output" + + # Outputs + collection: list[ColorField] = 
OutputField( + default_factory=list, description="The output colors", ui_type=UIType.ColorCollection + ) + + +@title("Color Primitive") +@tags("primitives", "color") +class ColorInvocation(BaseInvocation): + """A color primitive value""" + + type: Literal["color"] = "color" + + # Inputs + color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color value") + + def invoke(self, context: InvocationContext) -> ColorOutput: + return ColorOutput(color=self.color) + + +# endregion + +# region Conditioning + + +class ConditioningField(BaseModel): + """A conditioning tensor primitive value""" + + conditioning_name: str = Field(description="The name of conditioning tensor") + + +class ConditioningOutput(BaseInvocationOutput): + """Base class for nodes that output a single conditioning tensor""" + + type: Literal["conditioning_output"] = "conditioning_output" + + conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond) + + +class ConditioningCollectionOutput(BaseInvocationOutput): + """Base class for nodes that output a collection of conditioning tensors""" + + type: Literal["conditioning_collection_output"] = "conditioning_collection_output" + + # Outputs + collection: list[ConditioningField] = OutputField( + default_factory=list, + description="The output conditioning tensors", + ui_type=UIType.ConditioningCollection, + ) + + +@title("Conditioning Primitive") +@tags("primitives", "conditioning") +class ConditioningInvocation(BaseInvocation): + """A conditioning tensor primitive value""" + + type: Literal["conditioning"] = "conditioning" + + conditioning: ConditioningField = InputField(description=FieldDescriptions.cond, input=Input.Connection) + + def invoke(self, context: InvocationContext) -> ConditioningOutput: + return ConditioningOutput(conditioning=self.conditioning) + + +@title("Conditioning Collection") +@tags("primitives", "conditioning", "collection") +class 
ConditioningCollectionInvocation(BaseInvocation): + """A collection of conditioning tensor primitive values""" + + type: Literal["conditioning_collection"] = "conditioning_collection" + + # Inputs + collection: list[ConditioningField] = InputField( + default=0, description="The collection of conditioning tensors", ui_type=UIType.ConditioningCollection + ) + + def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput: + return ConditioningCollectionOutput(collection=self.collection) + + +# endregion diff --git a/invokeai/app/invocations/prompt.py b/invokeai/app/invocations/prompt.py index 83a397ddcf..acdb821456 100644 --- a/invokeai/app/invocations/prompt.py +++ b/invokeai/app/invocations/prompt.py @@ -1,59 +1,28 @@ from os.path import exists -from typing import Literal, Optional +from typing import Literal, Optional, Union import numpy as np -from pydantic import Field, validator +from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator +from pydantic import validator -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext -from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator - - -class PromptOutput(BaseInvocationOutput): - """Base class for invocations that output a prompt""" - - # fmt: off - type: Literal["prompt"] = "prompt" - - prompt: str = Field(default=None, description="The output prompt") - # fmt: on - - class Config: - schema_extra = { - "required": [ - "type", - "prompt", - ] - } - - -class PromptCollectionOutput(BaseInvocationOutput): - """Base class for invocations that output a collection of prompts""" - - # fmt: off - type: Literal["prompt_collection_output"] = "prompt_collection_output" - - prompt_collection: list[str] = Field(description="The output prompt collection") - count: int = Field(description="The size of the prompt collection") - # fmt: on - - class Config: - schema_extra = {"required": ["type", 
"prompt_collection", "count"]} +from invokeai.app.invocations.primitives import StringCollectionOutput + +from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, UIType, tags, title +@title("Dynamic Prompt") +@tags("prompt", "collection") class DynamicPromptInvocation(BaseInvocation): """Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator""" type: Literal["dynamic_prompt"] = "dynamic_prompt" - prompt: str = Field(description="The prompt to parse with dynamicprompts") - max_prompts: int = Field(default=1, description="The number of prompts to generate") - combinatorial: bool = Field(default=False, description="Whether to use the combinatorial generator") - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Dynamic Prompt", "tags": ["prompt", "dynamic"]}, - } + # Inputs + prompt: str = InputField(description="The prompt to parse with dynamicprompts", ui_component=UIComponent.Textarea) + max_prompts: int = InputField(default=1, description="The number of prompts to generate") + combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator") - def invoke(self, context: InvocationContext) -> PromptCollectionOutput: + def invoke(self, context: InvocationContext) -> StringCollectionOutput: if self.combinatorial: generator = CombinatorialPromptGenerator() prompts = generator.generate(self.prompt, max_prompts=self.max_prompts) @@ -61,27 +30,26 @@ class DynamicPromptInvocation(BaseInvocation): generator = RandomPromptGenerator() prompts = generator.generate(self.prompt, num_images=self.max_prompts) - return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts)) + return StringCollectionOutput(collection=prompts) +@title("Prompts from File") +@tags("prompt", "file") class PromptsFromFileInvocation(BaseInvocation): """Loads prompts from a text file""" - # fmt: off - type: Literal['prompt_from_file'] = 'prompt_from_file' + type: 
Literal["prompt_from_file"] = "prompt_from_file" # Inputs - file_path: str = Field(description="Path to prompt text file") - pre_prompt: Optional[str] = Field(description="String to prepend to each prompt") - post_prompt: Optional[str] = Field(description="String to append to each prompt") - start_line: int = Field(default=1, ge=1, description="Line in the file to start start from") - max_prompts: int = Field(default=1, ge=0, description="Max lines to read from file (0=all)") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Prompts From File", "tags": ["prompt", "file"]}, - } + file_path: str = InputField(description="Path to prompt text file", ui_type=UIType.FilePath) + pre_prompt: Optional[str] = InputField( + default=None, description="String to prepend to each prompt", ui_component=UIComponent.Textarea + ) + post_prompt: Optional[str] = InputField( + default=None, description="String to append to each prompt", ui_component=UIComponent.Textarea + ) + start_line: int = InputField(default=1, ge=1, description="Line in the file to start start from") + max_prompts: int = InputField(default=1, ge=0, description="Max lines to read from file (0=all)") @validator("file_path") def file_path_exists(cls, v): @@ -89,7 +57,14 @@ class PromptsFromFileInvocation(BaseInvocation): raise ValueError(FileNotFoundError) return v - def promptsFromFile(self, file_path: str, pre_prompt: str, post_prompt: str, start_line: int, max_prompts: int): + def promptsFromFile( + self, + file_path: str, + pre_prompt: Union[str, None], + post_prompt: Union[str, None], + start_line: int, + max_prompts: int, + ): prompts = [] start_line -= 1 end_line = start_line + max_prompts @@ -103,8 +78,8 @@ class PromptsFromFileInvocation(BaseInvocation): break return prompts - def invoke(self, context: InvocationContext) -> PromptCollectionOutput: + def invoke(self, context: InvocationContext) -> StringCollectionOutput: prompts = self.promptsFromFile( self.file_path, 
self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts ) - return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts)) + return StringCollectionOutput(collection=prompts) diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index a5a1c2c641..4efe30a3d9 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -1,55 +1,55 @@ -import torch from typing import Literal -from pydantic import Field from ...backend.model_management import ModelType, SubModelType -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext -from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo +from .baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + FieldDescriptions, + Input, + InputField, + InvocationContext, + OutputField, + UIType, + tags, + title, +) +from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField class SDXLModelLoaderOutput(BaseInvocationOutput): """SDXL base model loader output""" - # fmt: off type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output" - unet: UNetField = Field(default=None, description="UNet submodel") - clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") - clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") - vae: VaeField = Field(default=None, description="Vae submodel") - # fmt: on + unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet") + clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 1") + clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2") + vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE") class SDXLRefinerModelLoaderOutput(BaseInvocationOutput): """SDXL refiner model loader output""" - # fmt: off type: Literal["sdxl_refiner_model_loader_output"] = 
"sdxl_refiner_model_loader_output" - unet: UNetField = Field(default=None, description="UNet submodel") - clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") - vae: VaeField = Field(default=None, description="Vae submodel") - # fmt: on - # fmt: on + + unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet") + clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2") + vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE") +@title("SDXL Main Model Loader") +@tags("model", "sdxl") class SDXLModelLoaderInvocation(BaseInvocation): """Loads an sdxl base model, outputting its submodels.""" type: Literal["sdxl_model_loader"] = "sdxl_model_loader" - model: MainModelField = Field(description="The model to load") + # Inputs + model: MainModelField = InputField( + description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel + ) # TODO: precision? - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "SDXL Model Loader", - "tags": ["model", "loader", "sdxl"], - "type_hints": {"model": "model"}, - }, - } - def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name @@ -122,24 +122,21 @@ class SDXLModelLoaderInvocation(BaseInvocation): ) +@title("SDXL Refiner Model Loader") +@tags("model", "sdxl", "refiner") class SDXLRefinerModelLoaderInvocation(BaseInvocation): """Loads an sdxl refiner model, outputting its submodels.""" type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader" - model: MainModelField = Field(description="The model to load") + # Inputs + model: MainModelField = InputField( + description=FieldDescriptions.sdxl_refiner_model, + input=Input.Direct, + ui_type=UIType.SDXLRefinerModel, + ) # TODO: precision? 
- # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "SDXL Refiner Model Loader", - "tags": ["model", "loader", "sdxl_refiner"], - "type_hints": {"model": "refiner_model"}, - }, - } - def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index fd220223db..cd4463e174 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -6,13 +6,12 @@ import cv2 as cv import numpy as np from basicsr.archs.rrdbnet_arch import RRDBNet from PIL import Image -from pydantic import Field from realesrgan import RealESRGANer +from invokeai.app.invocations.primitives import ImageField, ImageOutput -from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin +from invokeai.app.models.image import ImageCategory, ResourceOrigin -from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext -from .image import ImageOutput +from .baseinvocation import BaseInvocation, InputField, InvocationContext, title, tags # TODO: Populate this from disk? # TODO: Use model manager to load? 
@@ -24,17 +23,16 @@ ESRGAN_MODELS = Literal[ ] +@title("Upscale (RealESRGAN)") +@tags("esrgan", "upscale") class ESRGANInvocation(BaseInvocation): """Upscales an image using RealESRGAN.""" type: Literal["esrgan"] = "esrgan" - image: Union[ImageField, None] = Field(default=None, description="The input image") - model_name: ESRGAN_MODELS = Field(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use") - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Upscale (RealESRGAN)", "tags": ["image", "upscale", "realesrgan"]}, - } + # Inputs + image: ImageField = InputField(description="The input image") + model_name: ESRGAN_MODELS = InputField(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use") def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) diff --git a/invokeai/app/models/image.py b/invokeai/app/models/image.py index 2a5a0f9d3b..88cf8af5f9 100644 --- a/invokeai/app/models/image.py +++ b/invokeai/app/models/image.py @@ -1,31 +1,8 @@ from enum import Enum -from typing import Optional, Tuple, Literal + from pydantic import BaseModel, Field from invokeai.app.util.metaenum import MetaEnum -from ..invocations.baseinvocation import ( - BaseInvocationOutput, - InvocationConfig, -) - - -class ImageField(BaseModel): - """An image field used for passing image objects between invocations""" - - image_name: Optional[str] = Field(default=None, description="The name of the image") - - class Config: - schema_extra = {"required": ["image_name"]} - - -class ColorField(BaseModel): - r: int = Field(ge=0, le=255, description="The red component") - g: int = Field(ge=0, le=255, description="The green component") - b: int = Field(ge=0, le=255, description="The blue component") - a: int = Field(ge=0, le=255, description="The alpha component") - - def tuple(self) -> Tuple[int, int, int, int]: - return (self.r, self.g, self.b, self.a) class 
ProgressImage(BaseModel): @@ -36,50 +13,6 @@ class ProgressImage(BaseModel): dataURL: str = Field(description="The image data as a b64 data URL") -class PILInvocationConfig(BaseModel): - """Helper class to provide all PIL invocations with additional config""" - - class Config(InvocationConfig): - schema_extra = { - "ui": { - "tags": ["PIL", "image"], - }, - } - - -class ImageOutput(BaseInvocationOutput): - """Base class for invocations that output an image""" - - # fmt: off - type: Literal["image_output"] = "image_output" - image: ImageField = Field(default=None, description="The output image") - width: int = Field(description="The width of the image in pixels") - height: int = Field(description="The height of the image in pixels") - # fmt: on - - class Config: - schema_extra = {"required": ["type", "image", "width", "height"]} - - -class MaskOutput(BaseInvocationOutput): - """Base class for invocations that output a mask""" - - # fmt: off - type: Literal["mask"] = "mask" - mask: ImageField = Field(default=None, description="The output mask") - width: int = Field(description="The width of the mask in pixels") - height: int = Field(description="The height of the mask in pixels") - # fmt: on - - class Config: - schema_extra = { - "required": [ - "type", - "mask", - ] - } - - class ResourceOrigin(str, Enum, metaclass=MetaEnum): """The origin of a resource (eg image). 
diff --git a/invokeai/app/services/default_graphs.py b/invokeai/app/services/default_graphs.py index caee5b631e..7135e031b0 100644 --- a/invokeai/app/services/default_graphs.py +++ b/invokeai/app/services/default_graphs.py @@ -2,7 +2,7 @@ from ..invocations.latent import LatentsToImageInvocation, DenoiseLatentsInvocat from ..invocations.image import ImageNSFWBlurInvocation from ..invocations.noise import NoiseInvocation from ..invocations.compel import CompelInvocation -from ..invocations.params import ParamIntInvocation +from ..invocations.primitives import IntegerInvocation from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph from .item_storage import ItemStorageABC @@ -17,9 +17,9 @@ def create_text_to_image() -> LibraryGraph: description="Converts text to an image", graph=Graph( nodes={ - "width": ParamIntInvocation(id="width", a=512), - "height": ParamIntInvocation(id="height", a=512), - "seed": ParamIntInvocation(id="seed", a=-1), + "width": IntegerInvocation(id="width", a=512), + "height": IntegerInvocation(id="height", a=512), + "seed": IntegerInvocation(id="seed", a=-1), "3": NoiseInvocation(id="3"), "4": CompelInvocation(id="4"), "5": CompelInvocation(id="5"), diff --git a/invokeai/app/services/graph.py b/invokeai/app/services/graph.py index d7f021df14..c412431b43 100644 --- a/invokeai/app/services/graph.py +++ b/invokeai/app/services/graph.py @@ -3,16 +3,7 @@ import copy import itertools import uuid -from typing import ( - Annotated, - Any, - Literal, - Optional, - Union, - get_args, - get_origin, - get_type_hints, -) +from typing import Annotated, Any, Literal, Optional, Union, get_args, get_origin, get_type_hints import networkx as nx from pydantic import BaseModel, root_validator, validator @@ -22,7 +13,11 @@ from ..invocations import * from ..invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, + Input, + InputField, InvocationContext, + OutputField, + UIType, ) # in 3.10 this would be 
"from types import NoneType" @@ -183,15 +178,9 @@ class IterateInvocationOutput(BaseInvocationOutput): type: Literal["iterate_output"] = "iterate_output" - item: Any = Field(description="The item being iterated over") - - class Config: - schema_extra = { - "required": [ - "type", - "item", - ] - } + item: Any = OutputField( + description="The item being iterated over", title="Collection Item", ui_type=UIType.CollectionItem + ) # TODO: Fill this out and move to invocations @@ -200,8 +189,10 @@ class IterateInvocation(BaseInvocation): type: Literal["iterate"] = "iterate" - collection: list[Any] = Field(description="The list of items to iterate over", default_factory=list) - index: int = Field(description="The index, will be provided on executed iterators", default=0) + collection: list[Any] = InputField( + description="The list of items to iterate over", default_factory=list, ui_type=UIType.Collection + ) + index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True) def invoke(self, context: InvocationContext) -> IterateInvocationOutput: """Produces the outputs as values""" @@ -211,15 +202,9 @@ class IterateInvocation(BaseInvocation): class CollectInvocationOutput(BaseInvocationOutput): type: Literal["collect_output"] = "collect_output" - collection: list[Any] = Field(description="The collection of input items") - - class Config: - schema_extra = { - "required": [ - "type", - "collection", - ] - } + collection: list[Any] = OutputField( + description="The collection of input items", title="Collection", ui_type=UIType.Collection + ) class CollectInvocation(BaseInvocation): @@ -227,13 +212,14 @@ class CollectInvocation(BaseInvocation): type: Literal["collect"] = "collect" - item: Any = Field( + item: Any = InputField( description="The item to collect (all inputs must be of the same type)", - default=None, + ui_type=UIType.CollectionItem, + title="Collection Item", + input=Input.Connection, ) - collection: list[Any] 
= Field( - description="The collection, will be provided on execution", - default_factory=list, + collection: list[Any] = InputField( + description="The collection, will be provided on execution", default_factory=list, ui_hidden=True ) def invoke(self, context: InvocationContext) -> CollectInvocationOutput: diff --git a/invokeai/app/services/image_record_storage.py b/invokeai/app/services/image_record_storage.py index c781911aa5..239c1392d9 100644 --- a/invokeai/app/services/image_record_storage.py +++ b/invokeai/app/services/image_record_storage.py @@ -67,6 +67,7 @@ IMAGE_DTO_COLS = ", ".join( "created_at", "updated_at", "deleted_at", + "starred", ], ) ) @@ -139,6 +140,7 @@ class ImageRecordStorageBase(ABC): node_id: Optional[str], metadata: Optional[dict], is_intermediate: bool = False, + starred: bool = False, ) -> datetime: """Saves an image record.""" pass @@ -198,6 +200,16 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): """ ) + self._cursor.execute("PRAGMA table_info(images)") + columns = [column[1] for column in self._cursor.fetchall()] + + if "starred" not in columns: + self._cursor.execute( + """--sql + ALTER TABLE images ADD COLUMN starred BOOLEAN DEFAULT FALSE; + """ + ) + # Create the `images` table indices. self._cursor.execute( """--sql @@ -220,6 +232,12 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): """ ) + self._cursor.execute( + """--sql + CREATE INDEX IF NOT EXISTS idx_images_starred ON images(starred); + """ + ) + # Add trigger for `updated_at`. self._cursor.execute( """--sql @@ -319,6 +337,17 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): (changes.is_intermediate, image_name), ) + # Change the image's `starred`` state + if changes.starred is not None: + self._cursor.execute( + f"""--sql + UPDATE images + SET starred = ? 
+ WHERE image_name = ?; + """, + (changes.starred, image_name), + ) + self._conn.commit() except sqlite3.Error as e: self._conn.rollback() @@ -395,7 +424,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): query_params.append(board_id) query_pagination = """--sql - ORDER BY images.created_at DESC LIMIT ? OFFSET ? + ORDER BY images.starred DESC, images.created_at DESC LIMIT ? OFFSET ? """ # Final images query with pagination @@ -498,6 +527,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): node_id: Optional[str], metadata: Optional[dict], is_intermediate: bool = False, + starred: bool = False, ) -> datetime: try: metadata_json = None if metadata is None else json.dumps(metadata) @@ -513,9 +543,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): node_id, session_id, metadata, - is_intermediate + is_intermediate, + starred ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?); + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?); """, ( image_name, @@ -527,6 +558,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): session_id, metadata_json, is_intermediate, + starred, ), ) self._conn.commit() diff --git a/invokeai/app/services/models/image_record.py b/invokeai/app/services/models/image_record.py index 294b760630..a480fd0efd 100644 --- a/invokeai/app/services/models/image_record.py +++ b/invokeai/app/services/models/image_record.py @@ -39,6 +39,8 @@ class ImageRecord(BaseModelExcludeNull): description="The node ID that generated this image, if it is a generated image.", ) """The node ID that generated this image, if it is a generated image.""" + starred: bool = Field(description="Whether this image is starred.") + """Whether this image is starred.""" class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid): @@ -48,6 +50,7 @@ class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid): - `image_category`: change the category of an image - `session_id`: change the session associated with an image - `is_intermediate`: change the image's 
`is_intermediate` flag + - `starred`: change whether the image is starred """ image_category: Optional[ImageCategory] = Field(description="The image's new category.") @@ -59,6 +62,8 @@ class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid): """The image's new session ID.""" is_intermediate: Optional[StrictBool] = Field(default=None, description="The image's new `is_intermediate` flag.") """The image's new `is_intermediate` flag.""" + starred: Optional[StrictBool] = Field(default=None, description="The image's new `starred` state") + """The image's new `starred` state.""" class ImageUrlsDTO(BaseModelExcludeNull): @@ -113,6 +118,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: updated_at = image_dict.get("updated_at", get_iso_timestamp()) deleted_at = image_dict.get("deleted_at", get_iso_timestamp()) is_intermediate = image_dict.get("is_intermediate", False) + starred = image_dict.get("starred", False) return ImageRecord( image_name=image_name, @@ -126,4 +132,5 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: updated_at=updated_at, deleted_at=deleted_at, is_intermediate=is_intermediate, + starred=starred, ) diff --git a/invokeai/app/services/processor.py b/invokeai/app/services/processor.py index 41170a304b..b8c2f93e93 100644 --- a/invokeai/app/services/processor.py +++ b/invokeai/app/services/processor.py @@ -87,7 +87,10 @@ class DefaultInvocationProcessor(InvocationProcessorABC): # Invoke try: with statistics.collect_stats(invocation, graph_execution_state.id): - outputs = invocation.invoke( + # use the internal invoke_internal(), which wraps the node's invoke() method; + # this accommodates nodes which require a value, but get it only from a + # connection + outputs = invocation.invoke_internal( InvocationContext( services=self.__invoker.services, graph_execution_state_id=graph_execution_state.id, diff --git a/invokeai/app/services/sqlite.py b/invokeai/app/services/sqlite.py index 251964dafd..99a0df2ce5 100644 ---
a/invokeai/app/services/sqlite.py +++ b/invokeai/app/services/sqlite.py @@ -45,7 +45,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): def _parse_item(self, item: str) -> T: item_type = get_args(self.__orig_class__)[0] - return parse_raw_as(item_type, item) + parsed = parse_raw_as(item_type, item) + return parsed def set(self, item: T): try: diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json index 8cc2c158be..6c9db74bbc 100644 --- a/invokeai/frontend/web/package.json +++ b/invokeai/frontend/web/package.json @@ -61,6 +61,7 @@ "@dagrejs/graphlib": "^2.1.13", "@dnd-kit/core": "^6.0.8", "@dnd-kit/modifiers": "^6.0.1", + "@dnd-kit/utilities": "^3.2.1", "@emotion/react": "^11.11.1", "@emotion/styled": "^11.11.0", "@floating-ui/react-dom": "^2.0.1", diff --git a/invokeai/frontend/web/scripts/colors.js b/invokeai/frontend/web/scripts/colors.js new file mode 100644 index 0000000000..3fc8f8d751 --- /dev/null +++ b/invokeai/frontend/web/scripts/colors.js @@ -0,0 +1,34 @@ +export const COLORS = { + reset: '\x1b[0m', + bright: '\x1b[1m', + dim: '\x1b[2m', + underscore: '\x1b[4m', + blink: '\x1b[5m', + reverse: '\x1b[7m', + hidden: '\x1b[8m', + + fg: { + black: '\x1b[30m', + red: '\x1b[31m', + green: '\x1b[32m', + yellow: '\x1b[33m', + blue: '\x1b[34m', + magenta: '\x1b[35m', + cyan: '\x1b[36m', + white: '\x1b[37m', + gray: '\x1b[90m', + crimson: '\x1b[38m', + }, + bg: { + black: '\x1b[40m', + red: '\x1b[41m', + green: '\x1b[42m', + yellow: '\x1b[43m', + blue: '\x1b[44m', + magenta: '\x1b[45m', + cyan: '\x1b[46m', + white: '\x1b[47m', + gray: '\x1b[100m', + crimson: '\x1b[48m', + }, +}; diff --git a/invokeai/frontend/web/scripts/typegen.js b/invokeai/frontend/web/scripts/typegen.js index ec67c48f2d..d105917e66 100644 --- a/invokeai/frontend/web/scripts/typegen.js +++ b/invokeai/frontend/web/scripts/typegen.js @@ -1,23 +1,83 @@ import fs from 'node:fs'; import openapiTS from 'openapi-typescript'; +import { COLORS } from './colors.js'; 
const OPENAPI_URL = 'http://127.0.0.1:9090/openapi.json'; const OUTPUT_FILE = 'src/services/api/schema.d.ts'; async function main() { process.stdout.write( - `Generating types "${OPENAPI_URL}" --> "${OUTPUT_FILE}"...` + `Generating types "${OPENAPI_URL}" --> "${OUTPUT_FILE}"...\n\n` ); const types = await openapiTS(OPENAPI_URL, { exportType: true, - transform: (schemaObject) => { + transform: (schemaObject, metadata) => { if ('format' in schemaObject && schemaObject.format === 'binary') { return schemaObject.nullable ? 'Blob | null' : 'Blob'; } + + /** + * Because invocations may have required fields that accept connection input, the generated + * types may be incorrect. + * + * For example, the ImageResizeInvocation has a required `image` field, but because it accepts + * connection input, it should be optional on instantiation of the field. + * + * To handle this, the schema exposes an `input` property that can be used to determine if the + * field accepts connection input. If it does, we can make the field optional. 
+ */ + + // Check if we are generating types for an invocation + const isInvocationPath = metadata.path.match( + /^#\/components\/schemas\/\w*Invocation$/ + ); + + const hasInvocationProperties = + schemaObject.properties && + ['id', 'is_intermediate', 'type'].every( + (prop) => prop in schemaObject.properties + ); + + if (isInvocationPath && hasInvocationProperties) { + // We only want to make fields optional if they are required + if (!Array.isArray(schemaObject?.required)) { + schemaObject.required = ['id', 'type']; + return; + } + + schemaObject.required.forEach((prop) => { + const acceptsConnection = ['any', 'connection'].includes( + schemaObject.properties?.[prop]?.['input'] + ); + + if (acceptsConnection) { + // remove this prop from the required array + const invocationName = metadata.path.split('/').pop(); + console.log( + `Making connectable field optional: ${COLORS.fg.green}${invocationName}.${COLORS.fg.cyan}${prop}${COLORS.reset}` + ); + schemaObject.required = schemaObject.required.filter( + (r) => r !== prop + ); + } + }); + + schemaObject.required = [ + ...new Set(schemaObject.required.concat(['id', 'type'])), + ]; + + return; + } + // if ( + // 'input' in schemaObject && + // (schemaObject.input === 'any' || schemaObject.input === 'connection') + // ) { + // schemaObject.required = false; + // } }, }); fs.writeFileSync(OUTPUT_FILE, types); - process.stdout.write(` OK!\r\n`); + process.stdout.write(`\nOK!\r\n`); } main(); diff --git a/invokeai/frontend/web/src/app/components/GlobalHotkeys.ts b/invokeai/frontend/web/src/app/components/GlobalHotkeys.ts index 9827e7f2b3..bbe77dc698 100644 --- a/invokeai/frontend/web/src/app/components/GlobalHotkeys.ts +++ b/invokeai/frontend/web/src/app/components/GlobalHotkeys.ts @@ -1,8 +1,12 @@ import { createSelector } from '@reduxjs/toolkit'; -import { RootState } from 'app/store/store'; +import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { 
requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; -import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice'; +import { + ctrlKeyPressed, + metaKeyPressed, + shiftKeyPressed, +} from 'features/ui/store/hotkeysSlice'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { setActiveTab, @@ -16,11 +20,11 @@ import React, { memo } from 'react'; import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook'; const globalHotkeysSelector = createSelector( - [(state: RootState) => state.hotkeys, (state: RootState) => state.ui], - (hotkeys, ui) => { - const { shift } = hotkeys; + [stateSelector], + ({ hotkeys, ui }) => { + const { shift, ctrl, meta } = hotkeys; const { shouldPinParametersPanel, shouldPinGallery } = ui; - return { shift, shouldPinGallery, shouldPinParametersPanel }; + return { shift, ctrl, meta, shouldPinGallery, shouldPinParametersPanel }; }, { memoizeOptions: { @@ -37,9 +41,8 @@ const globalHotkeysSelector = createSelector( */ const GlobalHotkeys: React.FC = () => { const dispatch = useAppDispatch(); - const { shift, shouldPinParametersPanel, shouldPinGallery } = useAppSelector( - globalHotkeysSelector - ); + const { shift, ctrl, meta, shouldPinParametersPanel, shouldPinGallery } = + useAppSelector(globalHotkeysSelector); const activeTabName = useAppSelector(activeTabNameSelector); useHotkeys( @@ -50,9 +53,19 @@ const GlobalHotkeys: React.FC = () => { } else { shift && dispatch(shiftKeyPressed(false)); } + if (isHotkeyPressed('ctrl')) { + !ctrl && dispatch(ctrlKeyPressed(true)); + } else { + ctrl && dispatch(ctrlKeyPressed(false)); + } + if (isHotkeyPressed('meta')) { + !meta && dispatch(metaKeyPressed(true)); + } else { + meta && dispatch(metaKeyPressed(false)); + } }, { keyup: true, keydown: true }, - [shift] + [shift, ctrl, meta] ); useHotkeys('o', () => { diff --git a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx index 
93b7825db7..7e2ed7f571 100644 --- a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx +++ b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx @@ -14,7 +14,7 @@ import { $authToken, $baseUrl, $projectId } from 'services/api/client'; import { socketMiddleware } from 'services/events/middleware'; import Loading from '../../common/components/Loading/Loading'; import '../../i18n'; -import ImageDndContext from './ImageDnd/ImageDndContext'; +import AppDndContext from '../../features/dnd/components/AppDndContext'; const App = lazy(() => import('./App')); const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider')); @@ -80,9 +80,9 @@ const InvokeAIUI = ({ }> - + - + diff --git a/invokeai/frontend/web/src/app/logging/logger.ts b/invokeai/frontend/web/src/app/logging/logger.ts index ef27c98d1f..7797b8dc92 100644 --- a/invokeai/frontend/web/src/app/logging/logger.ts +++ b/invokeai/frontend/web/src/app/logging/logger.ts @@ -19,7 +19,8 @@ type LoggerNamespace = | 'nodes' | 'system' | 'socketio' - | 'session'; + | 'session' + | 'dnd'; export const logger = (namespace: LoggerNamespace) => $logger.get().child({ namespace }); diff --git a/invokeai/frontend/web/src/app/store/middleware/devtools/actionsDenylist.ts b/invokeai/frontend/web/src/app/store/middleware/devtools/actionsDenylist.ts index 6d41d488c8..a596fce931 100644 --- a/invokeai/frontend/web/src/app/store/middleware/devtools/actionsDenylist.ts +++ b/invokeai/frontend/web/src/app/store/middleware/devtools/actionsDenylist.ts @@ -15,7 +15,7 @@ export const actionsDenylist = [ 'socket/socketGeneratorProgress', 'socket/appSocketGeneratorProgress', // every time user presses shift - 'hotkeys/shiftKeyPressed', + // 'hotkeys/shiftKeyPressed', // this happens after every state change '@@REMEMBER_PERSISTED', ]; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index c15b072a07..abb17d1eec 100644 
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -15,6 +15,7 @@ import { addDeleteBoardAndImagesFulfilledListener } from './listeners/boardAndIm import { addBoardIdSelectedListener } from './listeners/boardIdSelected'; import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard'; import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage'; +import { addCanvasMaskSavedToGalleryListener } from './listeners/canvasMaskSavedToGallery'; import { addCanvasMergedListener } from './listeners/canvasMerged'; import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery'; import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess'; @@ -27,8 +28,8 @@ import { addImageDeletedFulfilledListener, addImageDeletedPendingListener, addImageDeletedRejectedListener, - addRequestedSingleImageDeletionListener, addRequestedMultipleImageDeletionListener, + addRequestedSingleImageDeletionListener, } from './listeners/imageDeleted'; import { addImageDroppedListener } from './listeners/imageDropped'; import { @@ -79,6 +80,8 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas'; import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage'; import { addUserInvokedNodesListener } from './listeners/userInvokedNodes'; import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage'; +import { addImagesStarredListener } from './listeners/imagesStarred'; +import { addImagesUnstarredListener } from './listeners/imagesUnstarred'; export const listenerMiddleware = createListenerMiddleware(); @@ -120,6 +123,10 @@ addImageDeletedRejectedListener(); addDeleteBoardAndImagesFulfilledListener(); addImageToDeleteSelectedListener(); +// Image starred +addImagesStarredListener(); +addImagesUnstarredListener(); + // User 
Invoked addUserInvokedCanvasListener(); addUserInvokedNodesListener(); @@ -129,6 +136,7 @@ addSessionReadyToInvokeListener(); // Canvas actions addCanvasSavedToGalleryListener(); +addCanvasMaskSavedToGalleryListener(); addCanvasDownloadedAsImageListener(); addCanvasCopiedToClipboardListener(); addCanvasMergedListener(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMaskSavedToGallery.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMaskSavedToGallery.ts new file mode 100644 index 0000000000..e701b93352 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMaskSavedToGallery.ts @@ -0,0 +1,60 @@ +import { logger } from 'app/logging/logger'; +import { canvasMaskSavedToGallery } from 'features/canvas/store/actions'; +import { getCanvasData } from 'features/canvas/util/getCanvasData'; +import { addToast } from 'features/system/store/systemSlice'; +import { imagesApi } from 'services/api/endpoints/images'; +import { startAppListening } from '..'; + +export const addCanvasMaskSavedToGalleryListener = () => { + startAppListening({ + actionCreator: canvasMaskSavedToGallery, + effect: async (action, { dispatch, getState }) => { + const log = logger('canvas'); + const state = getState(); + + const canvasBlobsAndImageData = await getCanvasData( + state.canvas.layerState, + state.canvas.boundingBoxCoordinates, + state.canvas.boundingBoxDimensions, + state.canvas.isMaskEnabled, + state.canvas.shouldPreserveMaskedArea + ); + + if (!canvasBlobsAndImageData) { + return; + } + + const { maskBlob } = canvasBlobsAndImageData; + + if (!maskBlob) { + log.error('Problem getting mask layer blob'); + dispatch( + addToast({ + title: 'Problem Saving Mask', + description: 'Unable to export mask', + status: 'error', + }) + ); + return; + } + + const { autoAddBoardId } = state.gallery; + + dispatch( + imagesApi.endpoints.uploadImage.initiate({ + file: new 
File([maskBlob], 'canvasMaskImage.png', { + type: 'image/png', + }), + image_category: 'mask', + is_intermediate: false, + board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId, + crop_visible: true, + postUploadAction: { + type: 'TOAST', + toastOptions: { title: 'Mask Saved to Assets' }, + }, + }) + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts index 043105cb66..fc0b44653d 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts @@ -1,16 +1,20 @@ import { createAction } from '@reduxjs/toolkit'; -import { - TypesafeDraggableData, - TypesafeDroppableData, -} from 'app/components/ImageDnd/typesafeDnd'; import { logger } from 'app/logging/logger'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice'; +import { + TypesafeDraggableData, + TypesafeDroppableData, +} from 'features/dnd/types'; import { imageSelected } from 'features/gallery/store/gallerySlice'; -import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; +import { + fieldImageValueChanged, + workflowExposedFieldAdded, +} from 'features/nodes/store/nodesSlice'; import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { imagesApi } from 'services/api/endpoints/images'; import { startAppListening } from '../'; +import { parseify } from 'common/util/serialize'; export const dndDropped = createAction<{ overData: TypesafeDroppableData; @@ -21,7 +25,7 @@ export const addImageDroppedListener = () => { startAppListening({ actionCreator: dndDropped, effect: async (action, { dispatch }) => { - const log = logger('images'); + const log = logger('dnd'); 
const { activeData, overData } = action.payload; if (activeData.payloadType === 'IMAGE_DTO') { @@ -31,10 +35,28 @@ export const addImageDroppedListener = () => { { activeData, overData }, `Images (${activeData.payload.imageDTOs.length}) dropped` ); + } else if (activeData.payloadType === 'NODE_FIELD') { + log.debug( + { activeData: parseify(activeData), overData: parseify(overData) }, + 'Node field dropped' + ); } else { log.debug({ activeData, overData }, `Unknown payload dropped`); } + if ( + overData.actionType === 'ADD_FIELD_TO_LINEAR' && + activeData.payloadType === 'NODE_FIELD' + ) { + const { nodeId, field } = activeData.payload; + dispatch( + workflowExposedFieldAdded({ + nodeId, + fieldName: field.name, + }) + ); + } + /** * Image dropped on current image */ @@ -99,7 +121,7 @@ export const addImageDroppedListener = () => { ) { const { fieldName, nodeId } = overData.context; dispatch( - fieldValueChanged({ + fieldImageValueChanged({ nodeId, fieldName, value: activeData.payload.imageDTO, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts index 6dc2d482a9..0c55908748 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts @@ -2,7 +2,7 @@ import { UseToastOptions } from '@chakra-ui/react'; import { logger } from 'app/logging/logger'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice'; -import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; +import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice'; import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { addToast } from 
'features/system/store/systemSlice'; import { omit } from 'lodash-es'; @@ -111,7 +111,9 @@ export const addImageUploadedFulfilledListener = () => { if (postUploadAction?.type === 'SET_NODES_IMAGE') { const { nodeId, fieldName } = postUploadAction; - dispatch(fieldValueChanged({ nodeId, fieldName, value: imageDTO })); + dispatch( + fieldImageValueChanged({ nodeId, fieldName, value: imageDTO }) + ); dispatch( addToast({ ...DEFAULT_UPLOADED_TOAST, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imagesStarred.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imagesStarred.ts new file mode 100644 index 0000000000..5988eee207 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imagesStarred.ts @@ -0,0 +1,30 @@ +import { imagesApi } from 'services/api/endpoints/images'; +import { startAppListening } from '..'; +import { selectionChanged } from '../../../../../features/gallery/store/gallerySlice'; +import { ImageDTO } from '../../../../../services/api/types'; + +export const addImagesStarredListener = () => { + startAppListening({ + matcher: imagesApi.endpoints.starImages.matchFulfilled, + effect: async (action, { dispatch, getState }) => { + const { updated_image_names: starredImages } = action.payload; + + const state = getState(); + + const { selection } = state.gallery; + const updatedSelection: ImageDTO[] = []; + + selection.forEach((selectedImageDTO) => { + if (starredImages.includes(selectedImageDTO.image_name)) { + updatedSelection.push({ + ...selectedImageDTO, + starred: true, + }); + } else { + updatedSelection.push(selectedImageDTO); + } + }); + dispatch(selectionChanged(updatedSelection)); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imagesUnstarred.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imagesUnstarred.ts new file mode 100644 index 
0000000000..3df76861d4 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imagesUnstarred.ts @@ -0,0 +1,30 @@ +import { imagesApi } from 'services/api/endpoints/images'; +import { startAppListening } from '..'; +import { selectionChanged } from '../../../../../features/gallery/store/gallerySlice'; +import { ImageDTO } from '../../../../../services/api/types'; + +export const addImagesUnstarredListener = () => { + startAppListening({ + matcher: imagesApi.endpoints.unstarImages.matchFulfilled, + effect: async (action, { dispatch, getState }) => { + const { updated_image_names: unstarredImages } = action.payload; + + const state = getState(); + + const { selection } = state.gallery; + const updatedSelection: ImageDTO[] = []; + + selection.forEach((selectedImageDTO) => { + if (unstarredImages.includes(selectedImageDTO.image_name)) { + updatedSelection.push({ + ...selectedImageDTO, + starred: false, + }); + } else { + updatedSelection.push(selectedImageDTO); + } + }); + dispatch(selectionChanged(updatedSelection)); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts index 436a58aa8e..4d30ee3b8b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts @@ -15,12 +15,21 @@ import { setShouldUseSDXLRefiner, } from 'features/sdxl/store/sdxlSlice'; import { forEach, some } from 'lodash-es'; -import { modelsApi, vaeModelsAdapter } from 'services/api/endpoints/models'; +import { + mainModelsAdapter, + modelsApi, + vaeModelsAdapter, +} from 'services/api/endpoints/models'; +import { TypeGuardFor } from 'services/api/types'; import { startAppListening } from '..'; export const addModelsLoadedListener = () => { startAppListening({ - 
predicate: (state, action) => + predicate: ( + action + ): action is TypeGuardFor< + typeof modelsApi.endpoints.getMainModels.matchFulfilled + > => modelsApi.endpoints.getMainModels.matchFulfilled(action) && !action.meta.arg.originalArgs.includes('sdxl-refiner'), effect: async (action, { getState, dispatch }) => { @@ -32,29 +41,28 @@ export const addModelsLoadedListener = () => { ); const currentModel = getState().generation.model; + const models = mainModelsAdapter.getSelectors().selectAll(action.payload); - const isCurrentModelAvailable = some( - action.payload.entities, - (m) => - m?.model_name === currentModel?.model_name && - m?.base_model === currentModel?.base_model && - m?.model_type === currentModel?.model_type - ); - - if (isCurrentModelAvailable) { - return; - } - - const firstModelId = action.payload.ids[0]; - const firstModel = action.payload.entities[firstModelId]; - - if (!firstModel) { + if (models.length === 0) { // No models loaded at all dispatch(modelChanged(null)); return; } - const result = zMainOrOnnxModel.safeParse(firstModel); + const isCurrentModelAvailable = currentModel + ? 
models.some( + (m) => + m.model_name === currentModel.model_name && + m.base_model === currentModel.base_model && + m.model_type === currentModel.model_type + ) + : false; + + if (isCurrentModelAvailable) { + return; + } + + const result = zMainOrOnnxModel.safeParse(models[0]); if (!result.success) { log.error( @@ -68,7 +76,11 @@ export const addModelsLoadedListener = () => { }, }); startAppListening({ - predicate: (state, action) => + predicate: ( + action + ): action is TypeGuardFor< + typeof modelsApi.endpoints.getMainModels.matchFulfilled + > => modelsApi.endpoints.getMainModels.matchFulfilled(action) && action.meta.arg.originalArgs.includes('sdxl-refiner'), effect: async (action, { getState, dispatch }) => { @@ -80,30 +92,29 @@ export const addModelsLoadedListener = () => { ); const currentModel = getState().sdxl.refinerModel; + const models = mainModelsAdapter.getSelectors().selectAll(action.payload); - const isCurrentModelAvailable = some( - action.payload.entities, - (m) => - m?.model_name === currentModel?.model_name && - m?.base_model === currentModel?.base_model && - m?.model_type === currentModel?.model_type - ); - - if (isCurrentModelAvailable) { - return; - } - - const firstModelId = action.payload.ids[0]; - const firstModel = action.payload.entities[firstModelId]; - - if (!firstModel) { + if (models.length === 0) { // No models loaded at all dispatch(refinerModelChanged(null)); dispatch(setShouldUseSDXLRefiner(false)); return; } - const result = zSDXLRefinerModel.safeParse(firstModel); + const isCurrentModelAvailable = currentModel + ? 
models.some( + (m) => + m.model_name === currentModel.model_name && + m.base_model === currentModel.base_model && + m.model_type === currentModel.model_type + ) + : false; + + if (isCurrentModelAvailable) { + return; + } + + const result = zSDXLRefinerModel.safeParse(models[0]); if (!result.success) { log.error( diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts index 44729f215a..dd86c77735 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts @@ -13,7 +13,7 @@ export const addReceivedOpenAPISchemaListener = () => { const log = logger('system'); const schemaJSON = action.payload; - log.debug({ schemaJSON }, 'Dereferenced OpenAPI schema'); + log.debug({ schemaJSON }, 'Received OpenAPI schema'); const nodeTemplates = parseSchema(schemaJSON); @@ -28,9 +28,12 @@ export const addReceivedOpenAPISchemaListener = () => { startAppListening({ actionCreator: receivedOpenAPISchema.rejected, - effect: () => { + effect: (action) => { const log = logger('system'); - log.error('Problem dereferencing OpenAPI Schema'); + log.error( + { error: parseify(action.error) }, + 'Problem retrieving OpenAPI Schema' + ); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts index 5b3b9424b6..5501f208fd 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts @@ -19,7 +19,7 @@ import { } 
from 'services/events/actions'; import { startAppListening } from '../..'; -const nodeDenylist = ['dataURL_image']; +const nodeDenylist = ['load_image']; export const addInvocationCompleteEventListener = () => { startAppListening({ diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts index 0c298cbb24..5894bba5df 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts @@ -15,7 +15,7 @@ export const addUserInvokedNodesListener = () => { const log = logger('session'); const state = getState(); - const graph = buildNodesGraph(state); + const graph = buildNodesGraph(state.nodes); dispatch(nodesGraphBuilt(graph)); log.debug({ graph: parseify(graph) }, 'Nodes graph built'); diff --git a/invokeai/frontend/web/src/app/types/invokeai.ts b/invokeai/frontend/web/src/app/types/invokeai.ts index 827424fa7f..a39ed2ca7b 100644 --- a/invokeai/frontend/web/src/app/types/invokeai.ts +++ b/invokeai/frontend/web/src/app/types/invokeai.ts @@ -1,86 +1,7 @@ -import { - // CONTROLNET_MODELS, - CONTROLNET_PROCESSORS, -} from 'features/controlNet/store/constants'; +import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { InvokeTabName } from 'features/ui/store/tabMap'; import { O } from 'ts-toolbelt'; -// These are old types from the model management UI - -// export type ModelStatus = 'active' | 'cached' | 'not loaded'; - -// export type Model = { -// status: ModelStatus; -// description: string; -// weights: string; -// config?: string; -// vae?: string; -// width?: number; -// height?: number; -// default?: boolean; -// format?: string; -// }; - -// export type DiffusersModel = { -// status: ModelStatus; -// description: string; -// repo_id?: string; -// 
path?: string; -// vae?: { -// repo_id?: string; -// path?: string; -// }; -// format?: string; -// default?: boolean; -// }; - -// export type ModelList = Record; - -// export type FoundModel = { -// name: string; -// location: string; -// }; - -// export type InvokeModelConfigProps = { -// name: string | undefined; -// description: string | undefined; -// config: string | undefined; -// weights: string | undefined; -// vae: string | undefined; -// width: number | undefined; -// height: number | undefined; -// default: boolean | undefined; -// format: string | undefined; -// }; - -// export type InvokeDiffusersModelConfigProps = { -// name: string | undefined; -// description: string | undefined; -// repo_id: string | undefined; -// path: string | undefined; -// default: boolean | undefined; -// format: string | undefined; -// vae: { -// repo_id: string | undefined; -// path: string | undefined; -// }; -// }; - -// export type InvokeModelConversionProps = { -// model_name: string; -// save_location: string; -// custom_location: string | null; -// }; - -// export type InvokeModelMergingProps = { -// models_to_merge: string[]; -// alpha: number; -// interp: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference'; -// force: boolean; -// merged_model_name: string; -// model_merge_save_path: string | null; -// }; - /** * A disable-able application feature */ diff --git a/invokeai/frontend/web/src/common/components/IAIDndImage.tsx b/invokeai/frontend/web/src/common/components/IAIDndImage.tsx index 780447aba6..403a6cd5c5 100644 --- a/invokeai/frontend/web/src/common/components/IAIDndImage.tsx +++ b/invokeai/frontend/web/src/common/components/IAIDndImage.tsx @@ -1,16 +1,11 @@ import { ChakraProps, Flex, + FlexProps, Icon, Image, useColorMode, - useColorModeValue, } from '@chakra-ui/react'; -import { - TypesafeDraggableData, - TypesafeDroppableData, -} from 'app/components/ImageDnd/typesafeDnd'; -import IAIIconButton from 'common/components/IAIIconButton'; import { 
IAILoadingImageFallback, IAINoContentFallback, @@ -26,22 +21,22 @@ import { useCallback, useState, } from 'react'; -import { FaImage, FaUndo, FaUpload } from 'react-icons/fa'; +import { FaImage, FaUpload } from 'react-icons/fa'; import { ImageDTO, PostUploadAction } from 'services/api/types'; import { mode } from 'theme/util/mode'; import IAIDraggable from './IAIDraggable'; import IAIDroppable from './IAIDroppable'; import SelectionOverlay from './SelectionOverlay'; +import { + TypesafeDraggableData, + TypesafeDroppableData, +} from 'features/dnd/types'; -type IAIDndImageProps = { +type IAIDndImageProps = FlexProps & { imageDTO: ImageDTO | undefined; onError?: (event: SyntheticEvent) => void; onLoad?: (event: SyntheticEvent) => void; onClick?: (event: MouseEvent) => void; - onClickReset?: (event: MouseEvent) => void; - withResetIcon?: boolean; - resetIcon?: ReactElement; - resetTooltip?: string; withMetadataOverlay?: boolean; isDragDisabled?: boolean; isDropDisabled?: boolean; @@ -58,15 +53,14 @@ type IAIDndImageProps = { noContentFallback?: ReactElement; useThumbailFallback?: boolean; withHoverOverlay?: boolean; + children?: JSX.Element; }; const IAIDndImage = (props: IAIDndImageProps) => { const { imageDTO, - onClickReset, onError, onClick, - withResetIcon = false, withMetadataOverlay = false, isDropDisabled = false, isDragDisabled = false, @@ -80,32 +74,36 @@ const IAIDndImage = (props: IAIDndImageProps) => { dropLabel, isSelected = false, thumbnail = false, - resetTooltip = 'Reset', - resetIcon = , noContentFallback = , useThumbailFallback, withHoverOverlay = false, + children, + onMouseOver, + onMouseOut, } = props; const { colorMode } = useColorMode(); const [isHovered, setIsHovered] = useState(false); - const handleMouseOver = useCallback(() => { - setIsHovered(true); - }, []); - const handleMouseOut = useCallback(() => { - setIsHovered(false); - }, []); + const handleMouseOver = useCallback( + (e: MouseEvent) => { + if (onMouseOver) onMouseOver(e); + 
setIsHovered(true); + }, + [onMouseOver] + ); + const handleMouseOut = useCallback( + (e: MouseEvent) => { + if (onMouseOut) onMouseOut(e); + setIsHovered(false); + }, + [onMouseOut] + ); const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({ postUploadAction, isDisabled: isUploadDisabled, }); - const resetIconShadow = useColorModeValue( - `drop-shadow(0px 0px 0.1rem var(--invokeai-colors-base-600))`, - `drop-shadow(0px 0px 0.1rem var(--invokeai-colors-base-800))` - ); - const uploadButtonStyles = isUploadDisabled ? {} : { @@ -157,11 +155,10 @@ const IAIDndImage = (props: IAIDndImageProps) => { ) } - width={imageDTO.width} - height={imageDTO.height} onError={onError} draggable={false} sx={{ + w: imageDTO.width, objectFit: 'contain', maxW: 'full', maxH: 'full', @@ -220,30 +217,7 @@ const IAIDndImage = (props: IAIDndImageProps) => { dropLabel={dropLabel} /> )} - {onClickReset && withResetIcon && imageDTO && ( - - )} + {children} )} diff --git a/invokeai/frontend/web/src/common/components/IAIDndImageIcon.tsx b/invokeai/frontend/web/src/common/components/IAIDndImageIcon.tsx new file mode 100644 index 0000000000..f3d3fc0dda --- /dev/null +++ b/invokeai/frontend/web/src/common/components/IAIDndImageIcon.tsx @@ -0,0 +1,46 @@ +import { SystemStyleObject, useColorModeValue } from '@chakra-ui/react'; +import { MouseEvent, ReactElement, memo } from 'react'; +import IAIIconButton from './IAIIconButton'; + +type Props = { + onClick: (event: MouseEvent) => void; + tooltip: string; + icon?: ReactElement; + styleOverrides?: SystemStyleObject; +}; + +const IAIDndImageIcon = (props: Props) => { + const { onClick, tooltip, icon, styleOverrides } = props; + + const resetIconShadow = useColorModeValue( + `drop-shadow(0px 0px 0.1rem var(--invokeai-colors-base-600))`, + `drop-shadow(0px 0px 0.1rem var(--invokeai-colors-base-800))` + ); + return ( + + ); +}; + +export default memo(IAIDndImageIcon); diff --git 
a/invokeai/frontend/web/src/common/components/IAIDraggable.tsx b/invokeai/frontend/web/src/common/components/IAIDraggable.tsx index 482a8ac604..363799a573 100644 --- a/invokeai/frontend/web/src/common/components/IAIDraggable.tsx +++ b/invokeai/frontend/web/src/common/components/IAIDraggable.tsx @@ -1,22 +1,19 @@ -import { Box } from '@chakra-ui/react'; -import { - TypesafeDraggableData, - useDraggable, -} from 'app/components/ImageDnd/typesafeDnd'; -import { MouseEvent, memo, useRef } from 'react'; +import { Box, BoxProps } from '@chakra-ui/react'; +import { useDraggableTypesafe } from 'features/dnd/hooks/typesafeHooks'; +import { TypesafeDraggableData } from 'features/dnd/types'; +import { memo, useRef } from 'react'; import { v4 as uuidv4 } from 'uuid'; -type IAIDraggableProps = { +type IAIDraggableProps = BoxProps & { disabled?: boolean; data?: TypesafeDraggableData; - onClick?: (event: MouseEvent) => void; }; const IAIDraggable = (props: IAIDraggableProps) => { - const { data, disabled, onClick } = props; + const { data, disabled, ...rest } = props; const dndId = useRef(uuidv4()); - const { attributes, listeners, setNodeRef } = useDraggable({ + const { attributes, listeners, setNodeRef } = useDraggableTypesafe({ id: dndId.current, disabled, data, @@ -24,7 +21,6 @@ const IAIDraggable = (props: IAIDraggableProps) => { return ( { insetInlineStart={0} {...attributes} {...listeners} + {...rest} /> ); }; diff --git a/invokeai/frontend/web/src/common/components/IAIDroppable.tsx b/invokeai/frontend/web/src/common/components/IAIDroppable.tsx index 1038f36840..e4fb121c78 100644 --- a/invokeai/frontend/web/src/common/components/IAIDroppable.tsx +++ b/invokeai/frontend/web/src/common/components/IAIDroppable.tsx @@ -1,9 +1,7 @@ import { Box } from '@chakra-ui/react'; -import { - TypesafeDroppableData, - isValidDrop, - useDroppable, -} from 'app/components/ImageDnd/typesafeDnd'; +import { useDroppableTypesafe } from 'features/dnd/hooks/typesafeHooks'; +import { 
TypesafeDroppableData } from 'features/dnd/types'; +import { isValidDrop } from 'features/dnd/util/isValidDrop'; import { AnimatePresence } from 'framer-motion'; import { ReactNode, memo, useRef } from 'react'; import { v4 as uuidv4 } from 'uuid'; @@ -19,7 +17,7 @@ const IAIDroppable = (props: IAIDroppableProps) => { const { dropLabel, data, disabled } = props; const dndId = useRef(uuidv4()); - const { isOver, setNodeRef, active } = useDroppable({ + const { isOver, setNodeRef, active } = useDroppableTypesafe({ id: dndId.current, disabled, data, diff --git a/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx b/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx index 2057525b7a..a150e4ed0c 100644 --- a/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx +++ b/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx @@ -49,7 +49,7 @@ export const IAILoadingImageFallback = (props: Props) => { type IAINoImageFallbackProps = { label?: string; - icon?: As; + icon?: As | null; boxSize?: StyleProps['boxSize']; sx?: ChakraProps['sx']; }; @@ -76,7 +76,7 @@ export const IAINoContentFallback = (props: IAINoImageFallbackProps) => { ...props.sx, }} > - + {icon && } {props.label && {props.label}} ); diff --git a/invokeai/frontend/web/src/common/components/IAISwitch.tsx b/invokeai/frontend/web/src/common/components/IAISwitch.tsx index 9803626397..da0883d77e 100644 --- a/invokeai/frontend/web/src/common/components/IAISwitch.tsx +++ b/invokeai/frontend/web/src/common/components/IAISwitch.tsx @@ -1,10 +1,13 @@ import { + Flex, FormControl, FormControlProps, + FormHelperText, FormLabel, FormLabelProps, Switch, SwitchProps, + Text, Tooltip, } from '@chakra-ui/react'; import { memo } from 'react'; @@ -15,6 +18,7 @@ export interface IAISwitchProps extends SwitchProps { formControlProps?: FormControlProps; formLabelProps?: FormLabelProps; tooltip?: string; + helperText?: string; } /** @@ -28,6 +32,7 @@ const IAISwitch = (props: 
IAISwitchProps) => { formControlProps, formLabelProps, tooltip, + helperText, ...rest } = props; return ( @@ -35,25 +40,33 @@ const IAISwitch = (props: IAISwitchProps) => { - {label && ( - - {label} - - )} - + + + {label && ( + + {label} + + )} + + + {helperText && ( + + {helperText} + + )} + ); diff --git a/invokeai/frontend/web/src/common/hooks/useChakraThemeTokens.ts b/invokeai/frontend/web/src/common/hooks/useChakraThemeTokens.ts index 770add7253..0afb7e7e5d 100644 --- a/invokeai/frontend/web/src/common/hooks/useChakraThemeTokens.ts +++ b/invokeai/frontend/web/src/common/hooks/useChakraThemeTokens.ts @@ -40,6 +40,44 @@ export const useChakraThemeTokens = () => { accent850, accent900, accent950, + baseAlpha50, + baseAlpha100, + baseAlpha150, + baseAlpha200, + baseAlpha250, + baseAlpha300, + baseAlpha350, + baseAlpha400, + baseAlpha450, + baseAlpha500, + baseAlpha550, + baseAlpha600, + baseAlpha650, + baseAlpha700, + baseAlpha750, + baseAlpha800, + baseAlpha850, + baseAlpha900, + baseAlpha950, + accentAlpha50, + accentAlpha100, + accentAlpha150, + accentAlpha200, + accentAlpha250, + accentAlpha300, + accentAlpha350, + accentAlpha400, + accentAlpha450, + accentAlpha500, + accentAlpha550, + accentAlpha600, + accentAlpha650, + accentAlpha700, + accentAlpha750, + accentAlpha800, + accentAlpha850, + accentAlpha900, + accentAlpha950, ] = useToken('colors', [ 'base.50', 'base.100', @@ -79,6 +117,44 @@ export const useChakraThemeTokens = () => { 'accent.850', 'accent.900', 'accent.950', + 'baseAlpha.50', + 'baseAlpha.100', + 'baseAlpha.150', + 'baseAlpha.200', + 'baseAlpha.250', + 'baseAlpha.300', + 'baseAlpha.350', + 'baseAlpha.400', + 'baseAlpha.450', + 'baseAlpha.500', + 'baseAlpha.550', + 'baseAlpha.600', + 'baseAlpha.650', + 'baseAlpha.700', + 'baseAlpha.750', + 'baseAlpha.800', + 'baseAlpha.850', + 'baseAlpha.900', + 'baseAlpha.950', + 'accentAlpha.50', + 'accentAlpha.100', + 'accentAlpha.150', + 'accentAlpha.200', + 'accentAlpha.250', + 'accentAlpha.300', + 
'accentAlpha.350', + 'accentAlpha.400', + 'accentAlpha.450', + 'accentAlpha.500', + 'accentAlpha.550', + 'accentAlpha.600', + 'accentAlpha.650', + 'accentAlpha.700', + 'accentAlpha.750', + 'accentAlpha.800', + 'accentAlpha.850', + 'accentAlpha.900', + 'accentAlpha.950', ]); return { @@ -120,5 +196,43 @@ export const useChakraThemeTokens = () => { accent850, accent900, accent950, + baseAlpha50, + baseAlpha100, + baseAlpha150, + baseAlpha200, + baseAlpha250, + baseAlpha300, + baseAlpha350, + baseAlpha400, + baseAlpha450, + baseAlpha500, + baseAlpha550, + baseAlpha600, + baseAlpha650, + baseAlpha700, + baseAlpha750, + baseAlpha800, + baseAlpha850, + baseAlpha900, + baseAlpha950, + accentAlpha50, + accentAlpha100, + accentAlpha150, + accentAlpha200, + accentAlpha250, + accentAlpha300, + accentAlpha350, + accentAlpha400, + accentAlpha450, + accentAlpha500, + accentAlpha550, + accentAlpha600, + accentAlpha650, + accentAlpha700, + accentAlpha750, + accentAlpha800, + accentAlpha850, + accentAlpha900, + accentAlpha950, }; }; diff --git a/invokeai/frontend/web/src/common/util/serialize.ts b/invokeai/frontend/web/src/common/util/serialize.ts index a9352a8228..a5db921f8d 100644 --- a/invokeai/frontend/web/src/common/util/serialize.ts +++ b/invokeai/frontend/web/src/common/util/serialize.ts @@ -1,4 +1,10 @@ /** * Serialize an object to JSON and back to a new object */ -export const parseify = (obj: unknown) => JSON.parse(JSON.stringify(obj)); +export const parseify = (obj: unknown) => { + try { + return JSON.parse(JSON.stringify(obj)); + } catch { + return 'Error parsing object'; + } +}; diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasMaskOptions.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasMaskOptions.tsx index 2f74e5542a..25ef295631 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasMaskOptions.tsx +++ 
b/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasMaskOptions.tsx @@ -2,10 +2,11 @@ import { ButtonGroup, Flex } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIButton from 'common/components/IAIButton'; -import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox'; import IAIColorPicker from 'common/components/IAIColorPicker'; import IAIIconButton from 'common/components/IAIIconButton'; import IAIPopover from 'common/components/IAIPopover'; +import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox'; +import { canvasMaskSavedToGallery } from 'features/canvas/store/actions'; import { canvasSelector, isStagingSelector, @@ -22,7 +23,7 @@ import { isEqual } from 'lodash-es'; import { useHotkeys } from 'react-hotkeys-hook'; import { useTranslation } from 'react-i18next'; -import { FaMask, FaTrash } from 'react-icons/fa'; +import { FaMask, FaSave, FaTrash } from 'react-icons/fa'; export const selector = createSelector( [canvasSelector, isStagingSelector], @@ -102,6 +103,10 @@ const IAICanvasMaskOptions = () => { const handleToggleEnableMask = () => dispatch(setIsMaskEnabled(!isMaskEnabled)); + const handleSaveMask = async () => { + dispatch(canvasMaskSavedToGallery()); + }; + return ( { pickerColor={maskColor} onChange={(newColor) => dispatch(setMaskColor(newColor))} /> + } onClick={handleSaveMask}> + Save Mask + } onClick={handleClearMask}> {t('unifiedCanvas.clearMask')} (Shift+C) diff --git a/invokeai/frontend/web/src/features/canvas/store/actions.ts b/invokeai/frontend/web/src/features/canvas/store/actions.ts index 1f491874a0..b4efa76e42 100644 --- a/invokeai/frontend/web/src/features/canvas/store/actions.ts +++ b/invokeai/frontend/web/src/features/canvas/store/actions.ts @@ -3,6 +3,10 @@ import { ImageDTO } from 'services/api/types'; export const canvasSavedToGallery = createAction('canvas/canvasSavedToGallery'); 
+export const canvasMaskSavedToGallery = createAction( + 'canvas/canvasMaskSavedToGallery' +); + export const canvasCopiedToClipboard = createAction( 'canvas/canvasCopiedToClipboard' ); diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx index cdab176cd2..0683282811 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx @@ -4,14 +4,16 @@ import { skipToken } from '@reduxjs/toolkit/dist/query'; import { TypesafeDraggableData, TypesafeDroppableData, -} from 'app/components/ImageDnd/typesafeDnd'; +} from 'features/dnd/types'; import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIDndImage from 'common/components/IAIDndImage'; import { memo, useCallback, useMemo, useState } from 'react'; +import { FaUndo } from 'react-icons/fa'; import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { PostUploadAction } from 'services/api/types'; +import IAIDndImageIcon from '../../../common/components/IAIDndImageIcon'; import { ControlNetConfig, controlNetImageChanged, @@ -119,11 +121,15 @@ const ControlNetImagePreview = (props: Props) => { droppableData={droppableData} imageDTO={controlImage} isDropDisabled={shouldShowProcessedImage || !isEnabled} - onClickReset={handleResetControlImage} postUploadAction={postUploadAction} - resetTooltip="Reset Control Image" - withResetIcon={Boolean(controlImage)} - /> + > + : undefined} + tooltip="Reset Control Image" + /> + + { imageDTO={processedControlImage} isUploadDisabled={true} isDropDisabled={!isEnabled} - onClickReset={handleResetControlImage} - resetTooltip="Reset Control Image" - 
withResetIcon={Boolean(controlImage)} - /> + > + : undefined} + tooltip="Reset Control Image" + /> + {pendingControlImages.includes(controlNetId) && ( ; /** * Type guard for CannyImageProcessorInvocation diff --git a/invokeai/frontend/web/src/features/deleteImageModal/store/selectors.ts b/invokeai/frontend/web/src/features/deleteImageModal/store/selectors.ts index 310521f32a..37be06bad6 100644 --- a/invokeai/frontend/web/src/features/deleteImageModal/store/selectors.ts +++ b/invokeai/frontend/web/src/features/deleteImageModal/store/selectors.ts @@ -3,6 +3,7 @@ import { RootState } from 'app/store/store'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { some } from 'lodash-es'; import { ImageUsage } from './types'; +import { isInvocationNode } from 'features/nodes/types/types'; export const getImageUsage = (state: RootState, image_name: string) => { const { generation, canvas, nodes, controlNet } = state; @@ -12,11 +13,11 @@ export const getImageUsage = (state: RootState, image_name: string) => { (obj) => obj.kind === 'image' && obj.imageName === image_name ); - const isNodesImage = nodes.nodes.some((node) => { + const isNodesImage = nodes.nodes.filter(isInvocationNode).some((node) => { return some( node.data.inputs, (input) => - input.type === 'image' && input.value?.image_name === image_name + input.type === 'ImageField' && input.value?.image_name === image_name ); }); diff --git a/invokeai/frontend/web/src/app/components/ImageDnd/ImageDndContext.tsx b/invokeai/frontend/web/src/features/dnd/components/AppDndContext.tsx similarity index 70% rename from invokeai/frontend/web/src/app/components/ImageDnd/ImageDndContext.tsx rename to invokeai/frontend/web/src/features/dnd/components/AppDndContext.tsx index 56eeb9b5db..bffe738aa9 100644 --- a/invokeai/frontend/web/src/app/components/ImageDnd/ImageDndContext.tsx +++ b/invokeai/frontend/web/src/features/dnd/components/AppDndContext.tsx @@ -6,23 +6,18 @@ import { useSensor, 
useSensors, } from '@dnd-kit/core'; -import { snapCenterToCursor } from '@dnd-kit/modifiers'; +import { logger } from 'app/logging/logger'; import { dndDropped } from 'app/store/middleware/listenerMiddleware/listeners/imageDropped'; import { useAppDispatch } from 'app/store/storeHooks'; +import { parseify } from 'common/util/serialize'; import { AnimatePresence, motion } from 'framer-motion'; import { PropsWithChildren, memo, useCallback, useState } from 'react'; +import { useScaledModifer } from '../hooks/useScaledCenteredModifer'; +import { DragEndEvent, DragStartEvent, TypesafeDraggableData } from '../types'; +import { DndContextTypesafe } from './DndContextTypesafe'; import DragPreview from './DragPreview'; -import { - DndContext, - DragEndEvent, - DragStartEvent, - TypesafeDraggableData, -} from './typesafeDnd'; -import { logger } from 'app/logging/logger'; -type ImageDndContextProps = PropsWithChildren; - -const ImageDndContext = (props: ImageDndContextProps) => { +const AppDndContext = (props: PropsWithChildren) => { const [activeDragData, setActiveDragData] = useState(null); const log = logger('images'); @@ -31,7 +26,10 @@ const ImageDndContext = (props: ImageDndContextProps) => { const handleDragStart = useCallback( (event: DragStartEvent) => { - log.trace({ dragData: event.active.data.current }, 'Drag started'); + log.trace( + { dragData: parseify(event.active.data.current) }, + 'Drag started' + ); const activeData = event.active.data.current; if (!activeData) { return; @@ -43,7 +41,10 @@ const ImageDndContext = (props: ImageDndContextProps) => { const handleDragEnd = useCallback( (event: DragEndEvent) => { - log.trace({ dragData: event.active.data.current }, 'Drag ended'); + log.trace( + { dragData: parseify(event.active.data.current) }, + 'Drag ended' + ); const overData = event.over?.data.current; if (!activeDragData || !overData) { return; @@ -69,15 +70,29 @@ const ImageDndContext = (props: ImageDndContextProps) => { const sensors = 
useSensors(mouseSensor, touchSensor); + const scaledModifier = useScaledModifer(); + return ( - {props.children} - + {activeDragData && ( { )} - + ); }; -export default memo(ImageDndContext); +export default memo(AppDndContext); diff --git a/invokeai/frontend/web/src/features/dnd/components/DndContextTypesafe.tsx b/invokeai/frontend/web/src/features/dnd/components/DndContextTypesafe.tsx new file mode 100644 index 0000000000..06fede4dc8 --- /dev/null +++ b/invokeai/frontend/web/src/features/dnd/components/DndContextTypesafe.tsx @@ -0,0 +1,6 @@ +import { DndContext } from '@dnd-kit/core'; +import { DndContextTypesafeProps } from '../types'; + +export function DndContextTypesafe(props: DndContextTypesafeProps) { + return ; +} diff --git a/invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx b/invokeai/frontend/web/src/features/dnd/components/DragPreview.tsx similarity index 69% rename from invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx rename to invokeai/frontend/web/src/features/dnd/components/DragPreview.tsx index c97778ffcd..0ee5d34b1a 100644 --- a/invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx +++ b/invokeai/frontend/web/src/features/dnd/components/DragPreview.tsx @@ -1,6 +1,6 @@ -import { Box, ChakraProps, Flex, Heading, Image } from '@chakra-ui/react'; +import { Box, ChakraProps, Flex, Heading, Image, Text } from '@chakra-ui/react'; import { memo } from 'react'; -import { TypesafeDraggableData } from './typesafeDnd'; +import { TypesafeDraggableData } from '../types'; type OverlayDragImageProps = { dragData: TypesafeDraggableData | null; @@ -30,19 +30,38 @@ const DragPreview = (props: OverlayDragImageProps) => { return null; } + if (props.dragData.payloadType === 'NODE_FIELD') { + const { field, fieldTemplate } = props.dragData.payload; + return ( + + {field.label || fieldTemplate.title} + + ); + } + if (props.dragData.payloadType === 'IMAGE_DTO') { const { thumbnail_url, width, height } = 
props.dragData.payload.imageDTO; return ( { return ( + activeTabName === 'nodes' ? nodes.viewport.zoom : 1 +); + +/** + * Applies scaling to the drag transform (if on node editor tab) and centers it on cursor. + */ +export const useScaledModifer = () => { + const zoom = useAppSelector(selectZoom); + const modifier: Modifier = useCallback( + ({ activatorEvent, draggingNodeRect, transform }) => { + if (draggingNodeRect && activatorEvent) { + const activatorCoordinates = getEventCoordinates(activatorEvent); + + if (!activatorCoordinates) { + return transform; + } + + const offsetX = activatorCoordinates.x - draggingNodeRect.left; + const offsetY = activatorCoordinates.y - draggingNodeRect.top; + + const x = transform.x + offsetX - draggingNodeRect.width / 2; + const y = transform.y + offsetY - draggingNodeRect.height / 2; + const scaleX = transform.scaleX * zoom; + const scaleY = transform.scaleY * zoom; + + return { + x, + y, + scaleX, + scaleY, + }; + } + + return transform; + }, + [zoom] + ); + + return modifier; +}; diff --git a/invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx b/invokeai/frontend/web/src/features/dnd/types/index.ts similarity index 51% rename from invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx rename to invokeai/frontend/web/src/features/dnd/types/index.ts index 6f24302070..294132d0a3 100644 --- a/invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx +++ b/invokeai/frontend/web/src/features/dnd/types/index.ts @@ -3,7 +3,6 @@ import { Active, Collision, DndContextProps, - DndContext as OriginalDndContext, Over, Translate, UseDraggableArguments, @@ -11,6 +10,10 @@ import { useDraggable as useOriginalDraggable, useDroppable as useOriginalDroppable, } from '@dnd-kit/core'; +import { + InputFieldTemplate, + InputFieldValue, +} from 'features/nodes/types/types'; import { ImageDTO } from 'services/api/types'; type BaseDropData = { @@ -62,6 +65,10 @@ export type RemoveFromBoardDropData = BaseDropData & { 
actionType: 'REMOVE_FROM_BOARD'; }; +export type AddFieldToLinearViewDropData = BaseDropData & { + actionType: 'ADD_FIELD_TO_LINEAR'; +}; + export type TypesafeDroppableData = | CurrentImageDropData | InitialImageDropData @@ -71,12 +78,22 @@ export type TypesafeDroppableData = | AddToBatchDropData | NodesMultiImageDropData | AddToBoardDropData - | RemoveFromBoardDropData; + | RemoveFromBoardDropData + | AddFieldToLinearViewDropData; type BaseDragData = { id: string; }; +export type NodeFieldDraggableData = BaseDragData & { + payloadType: 'NODE_FIELD'; + payload: { + nodeId: string; + field: InputFieldValue; + fieldTemplate: InputFieldTemplate; + }; +}; + export type ImageDraggableData = BaseDragData & { payloadType: 'IMAGE_DTO'; payload: { imageDTO: ImageDTO }; @@ -87,14 +104,17 @@ export type ImageDTOsDraggableData = BaseDragData & { payload: { imageDTOs: ImageDTO[] }; }; -export type TypesafeDraggableData = ImageDraggableData | ImageDTOsDraggableData; +export type TypesafeDraggableData = + | NodeFieldDraggableData + | ImageDraggableData + | ImageDTOsDraggableData; -interface UseDroppableTypesafeArguments +export interface UseDroppableTypesafeArguments extends Omit { data?: TypesafeDroppableData; } -type UseDroppableTypesafeReturnValue = Omit< +export type UseDroppableTypesafeReturnValue = Omit< ReturnType, 'active' | 'over' > & { @@ -102,16 +122,12 @@ type UseDroppableTypesafeReturnValue = Omit< over: TypesafeOver | null; }; -export function useDroppable(props: UseDroppableTypesafeArguments) { - return useOriginalDroppable(props) as UseDroppableTypesafeReturnValue; -} - -interface UseDraggableTypesafeArguments +export interface UseDraggableTypesafeArguments extends Omit { data?: TypesafeDraggableData; } -type UseDraggableTypesafeReturnValue = Omit< +export type UseDraggableTypesafeReturnValue = Omit< ReturnType, 'active' | 'over' > & { @@ -119,102 +135,14 @@ type UseDraggableTypesafeReturnValue = Omit< over: TypesafeOver | null; }; -export function 
useDraggable(props: UseDraggableTypesafeArguments) { - return useOriginalDraggable(props) as UseDraggableTypesafeReturnValue; -} - -interface TypesafeActive extends Omit { +export interface TypesafeActive extends Omit { data: React.MutableRefObject; } -interface TypesafeOver extends Omit { +export interface TypesafeOver extends Omit { data: React.MutableRefObject; } -export const isValidDrop = ( - overData: TypesafeDroppableData | undefined, - active: TypesafeActive | null -) => { - if (!overData || !active?.data.current) { - return false; - } - - const { actionType } = overData; - const { payloadType } = active.data.current; - - if (overData.id === active.data.current.id) { - return false; - } - - switch (actionType) { - case 'SET_CURRENT_IMAGE': - return payloadType === 'IMAGE_DTO'; - case 'SET_INITIAL_IMAGE': - return payloadType === 'IMAGE_DTO'; - case 'SET_CONTROLNET_IMAGE': - return payloadType === 'IMAGE_DTO'; - case 'SET_CANVAS_INITIAL_IMAGE': - return payloadType === 'IMAGE_DTO'; - case 'SET_NODES_IMAGE': - return payloadType === 'IMAGE_DTO'; - case 'SET_MULTI_NODES_IMAGE': - return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS'; - case 'ADD_TO_BATCH': - return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS'; - case 'ADD_TO_BOARD': { - // If the board is the same, don't allow the drop - - // Check the payload types - const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS'; - if (!isPayloadValid) { - return false; - } - - // Check if the image's board is the board we are dragging onto - if (payloadType === 'IMAGE_DTO') { - const { imageDTO } = active.data.current.payload; - const currentBoard = imageDTO.board_id ?? 
'none'; - const destinationBoard = overData.context.boardId; - - return currentBoard !== destinationBoard; - } - - if (payloadType === 'IMAGE_DTOS') { - // TODO (multi-select) - return true; - } - - return false; - } - case 'REMOVE_FROM_BOARD': { - // If the board is the same, don't allow the drop - - // Check the payload types - const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS'; - if (!isPayloadValid) { - return false; - } - - // Check if the image's board is the board we are dragging onto - if (payloadType === 'IMAGE_DTO') { - const { imageDTO } = active.data.current.payload; - const currentBoard = imageDTO.board_id; - - return currentBoard !== 'none'; - } - - if (payloadType === 'IMAGE_DTOS') { - // TODO (multi-select) - return true; - } - - return false; - } - default: - return false; - } -}; - interface DragEvent { activatorEvent: Event; active: TypesafeActive; @@ -240,6 +168,3 @@ export interface DndContextTypesafeProps onDragEnd?(event: DragEndEvent): void; onDragCancel?(event: DragCancelEvent): void; } -export function DndContext(props: DndContextTypesafeProps) { - return ; -} diff --git a/invokeai/frontend/web/src/features/dnd/util/isValidDrop.ts b/invokeai/frontend/web/src/features/dnd/util/isValidDrop.ts new file mode 100644 index 0000000000..f704d22dff --- /dev/null +++ b/invokeai/frontend/web/src/features/dnd/util/isValidDrop.ts @@ -0,0 +1,93 @@ +import { TypesafeActive, TypesafeDroppableData } from '../types'; + +/** + * Decides whether the dragged item (`active`) may be dropped on `overData`. + * Returns false when either side is missing, when both carry the same id, or + * when the payload type is not one the droppable's `actionType` accepts. + * Board actions also reject single-image drops that would change nothing. + */ +export const isValidDrop = ( + overData: TypesafeDroppableData | undefined, + active: TypesafeActive | null +) => { + if (!overData || !active?.data.current) { + return false; + } + + const { actionType } = overData; + const { payloadType } = active.data.current; + + if (overData.id === active.data.current.id) { + return false; + } + + switch (actionType) { + case 'ADD_FIELD_TO_LINEAR': + return payloadType === 'NODE_FIELD'; + case 'SET_CURRENT_IMAGE': + return payloadType === 'IMAGE_DTO'; + case 'SET_INITIAL_IMAGE': + return
payloadType === 'IMAGE_DTO'; + case 'SET_CONTROLNET_IMAGE': + return payloadType === 'IMAGE_DTO'; + case 'SET_CANVAS_INITIAL_IMAGE': + return payloadType === 'IMAGE_DTO'; + case 'SET_NODES_IMAGE': + return payloadType === 'IMAGE_DTO'; + case 'SET_MULTI_NODES_IMAGE': + return payloadType === 'IMAGE_DTO' || payloadType === 'IMAGE_DTOS'; + case 'ADD_TO_BATCH': + return payloadType === 'IMAGE_DTO' || payloadType === 'IMAGE_DTOS'; + case 'ADD_TO_BOARD': { + // If the board is the same, don't allow the drop + + // Check the payload types + const isPayloadValid = payloadType === 'IMAGE_DTO' || payloadType === 'IMAGE_DTOS'; + if (!isPayloadValid) { + return false; + } + + // Check if the image's board is the board we are dragging onto + if (payloadType === 'IMAGE_DTO') { + const { imageDTO } = active.data.current.payload; + const currentBoard = imageDTO.board_id ?? 'none'; + const destinationBoard = overData.context.boardId; + + return currentBoard !== destinationBoard; + } + + if (payloadType === 'IMAGE_DTOS') { + // TODO (multi-select) + return true; + } + + return false; + } + case 'REMOVE_FROM_BOARD': { + // If the image is not on a board, don't allow the drop + + // Check the payload types + const isPayloadValid = payloadType === 'IMAGE_DTO' || payloadType === 'IMAGE_DTOS'; + if (!isPayloadValid) { + return false; + } + + // Check if the image is on a board; a missing board_id counts as 'none' + if (payloadType === 'IMAGE_DTO') { + const { imageDTO } = active.data.current.payload; + const currentBoard = imageDTO.board_id ?? 'none'; + + return currentBoard !== 'none'; + } + + if (payloadType === 'IMAGE_DTOS') { + // TODO (multi-select) + return true; + } + + return false; + } + default: + return false; + } +}; diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx index 228ce7080c..696a8b748b 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx +++
b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx @@ -11,7 +11,6 @@ import { } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; import { skipToken } from '@reduxjs/toolkit/dist/query'; -import { AddToBoardDropData } from 'app/components/ImageDnd/typesafeDnd'; import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; @@ -32,6 +31,7 @@ import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { BoardDTO } from 'services/api/types'; import AutoAddIcon from '../AutoAddIcon'; import BoardContextMenu from '../BoardContextMenu'; +import { AddToBoardDropData } from 'features/dnd/types'; interface GalleryBoardProps { board: BoardDTO; diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GenericBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GenericBoard.tsx index 0d630c524d..1698a81ac0 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GenericBoard.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GenericBoard.tsx @@ -1,7 +1,7 @@ import { As, Badge, Flex } from '@chakra-ui/react'; -import { TypesafeDroppableData } from 'app/components/ImageDnd/typesafeDnd'; import IAIDroppable from 'common/components/IAIDroppable'; import { IAINoContentFallback } from 'common/components/IAIImageFallback'; +import { TypesafeDroppableData } from 'features/dnd/types'; import { BoardId } from 'features/gallery/store/types'; import { ReactNode } from 'react'; import BoardContextMenu from '../BoardContextMenu'; diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx index f1341b1146..fec280db0f 100644 --- 
a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx @@ -1,15 +1,15 @@ import { Box, Flex, Image, Text } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; -import { RemoveFromBoardDropData } from 'app/components/ImageDnd/typesafeDnd'; import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import InvokeAILogoImage from 'assets/images/logo.png'; import IAIDroppable from 'common/components/IAIDroppable'; import SelectionOverlay from 'common/components/SelectionOverlay'; +import { RemoveFromBoardDropData } from 'features/dnd/types'; import { - boardIdSelected, autoAddBoardIdChanged, + boardIdSelected, } from 'features/gallery/store/gallerySlice'; import { memo, useCallback, useMemo, useState } from 'react'; import { useBoardName } from 'services/api/hooks/useBoardName'; diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImagePreview.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImagePreview.tsx index f78ee286ef..2576c8e9e3 100644 --- a/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImagePreview.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImagePreview.tsx @@ -1,14 +1,14 @@ import { Box, Flex, Image } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; import { skipToken } from '@reduxjs/toolkit/dist/query'; -import { - TypesafeDraggableData, - TypesafeDroppableData, -} from 'app/components/ImageDnd/typesafeDnd'; import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import IAIDndImage from 'common/components/IAIDndImage'; import { IAINoContentFallback } from 
'common/components/IAIImageFallback'; +import { + TypesafeDraggableData, + TypesafeDroppableData, +} from 'features/dnd/types'; import { useNextPrevImage } from 'features/gallery/hooks/useNextPrevImage'; import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors'; import { AnimatePresence, motion } from 'framer-motion'; diff --git a/invokeai/frontend/web/src/features/gallery/components/GallerySettingsPopover.tsx b/invokeai/frontend/web/src/features/gallery/components/GallerySettingsPopover.tsx index 5c32cc788e..23cfdcc5fd 100644 --- a/invokeai/frontend/web/src/features/gallery/components/GallerySettingsPopover.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/GallerySettingsPopover.tsx @@ -11,7 +11,6 @@ import { autoAssignBoardOnClickChanged, setGalleryImageMinimumWidth, shouldAutoSwitchChanged, - shouldShowDeleteButtonChanged, } from 'features/gallery/store/gallerySlice'; import { ChangeEvent, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; @@ -26,14 +25,12 @@ const selector = createSelector( galleryImageMinimumWidth, shouldAutoSwitch, autoAssignBoardOnClick, - shouldShowDeleteButton, } = state.gallery; return { galleryImageMinimumWidth, shouldAutoSwitch, autoAssignBoardOnClick, - shouldShowDeleteButton, }; }, defaultSelectorOptions @@ -43,12 +40,8 @@ const GallerySettingsPopover = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { - galleryImageMinimumWidth, - shouldAutoSwitch, - autoAssignBoardOnClick, - shouldShowDeleteButton, - } = useAppSelector(selector); + const { galleryImageMinimumWidth, shouldAutoSwitch, autoAssignBoardOnClick } = + useAppSelector(selector); const handleChangeGalleryImageMinimumWidth = useCallback( (v: number) => { @@ -68,13 +61,6 @@ const GallerySettingsPopover = () => { [dispatch] ); - const handleChangeShowDeleteButton = useCallback( - (e: ChangeEvent) => { - dispatch(shouldShowDeleteButtonChanged(e.target.checked)); - }, - [dispatch] - 
); - return ( { { isChecked={shouldAutoSwitch} onChange={handleChangeAutoSwitch} /> - { const dispatch = useAppDispatch(); const selection = useAppSelector((state) => state.gallery.selection); + const [starImages] = useStarImagesMutation(); + const [unstarImages] = useUnstarImagesMutation(); + const handleChangeBoard = useCallback(() => { dispatch(imagesToChangeSelected(selection)); dispatch(isModalOpenChanged(true)); @@ -21,8 +29,37 @@ const MultipleSelectionMenuItems = () => { dispatch(imagesToDeleteSelected(selection)); }, [dispatch, selection]); + const handleStarSelection = useCallback(() => { + starImages({ imageDTOs: selection }); + }, [starImages, selection]); + + const handleUnstarSelection = useCallback(() => { + unstarImages({ imageDTOs: selection }); + }, [unstarImages, selection]); + + const areAllStarred = useMemo(() => { + return selection.every((img) => img.starred); + }, [selection]); + + const areAllUnstarred = useMemo(() => { + return selection.every((img) => !img.starred); + }, [selection]); + return ( <> + {areAllStarred && ( + } + onClickCapture={handleUnstarSelection} + > + Unstar All + + )} + {(areAllUnstarred || (!areAllStarred && !areAllUnstarred)) && ( + } onClickCapture={handleStarSelection}> + Star All + + )} } onClickCapture={handleChangeBoard}> Change Board diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx index f857abf5ff..ef6e2ccd5c 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx @@ -29,10 +29,15 @@ import { FaShare, FaTrash, } from 'react-icons/fa'; -import { useGetImageMetadataQuery } from 'services/api/endpoints/images'; +import { + useGetImageMetadataQuery, + useStarImagesMutation, + 
useUnstarImagesMutation, +} from 'services/api/endpoints/images'; import { ImageDTO } from 'services/api/types'; import { useDebounce } from 'use-debounce'; import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions'; +import { MdStar, MdStarBorder } from 'react-icons/md'; type SingleSelectionMenuItemsProps = { imageDTO: ImageDTO; @@ -59,6 +64,9 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => { : debouncedMetadataQueryArg ?? skipToken ); + const [starImages] = useStarImagesMutation(); + const [unstarImages] = useUnstarImagesMutation(); + const { isClipboardAPIAvailable, copyImageToClipboard } = useCopyImageToClipboard(); @@ -127,6 +135,14 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => { copyImageToClipboard(imageDTO.image_url); }, [copyImageToClipboard, imageDTO.image_url]); + const handleStarImage = useCallback(() => { + if (imageDTO) starImages({ imageDTOs: [imageDTO] }); + }, [starImages, imageDTO]); + + const handleUnstarImage = useCallback(() => { + if (imageDTO) unstarImages({ imageDTOs: [imageDTO] }); + }, [unstarImages, imageDTO]); + return ( <> { } onClickCapture={handleChangeBoard}> Change Board + {imageDTO.starred ? 
( + } onClickCapture={handleUnstarImage}> + Unstar Image + + ) : ( + } onClickCapture={handleStarImage}> + Star Image + + )} } diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx index f2ff2ad30b..804df49b8e 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx @@ -52,11 +52,13 @@ const ImageGalleryContent = () => { return ( diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx index c9eee5f1f5..5dbbf011e8 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx @@ -1,17 +1,23 @@ import { Box, Flex } from '@chakra-ui/react'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import IAIDndImage from 'common/components/IAIDndImage'; +import IAIFillSkeleton from 'common/components/IAIFillSkeleton'; +import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice'; import { ImageDTOsDraggableData, ImageDraggableData, TypesafeDraggableData, -} from 'app/components/ImageDnd/typesafeDnd'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import IAIDndImage from 'common/components/IAIDndImage'; -import IAIFillSkeleton from 'common/components/IAIFillSkeleton'; +} from 'features/dnd/types'; import { useMultiselect } from 'features/gallery/hooks/useMultiselect.ts'; -import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice'; -import { MouseEvent, memo, useCallback, useMemo } from 'react'; +import { MouseEvent, memo, useCallback, useMemo, useState } from 'react'; import { FaTrash } from 'react-icons/fa'; -import { 
useGetImageDTOQuery } from 'services/api/endpoints/images'; +import { MdStar, MdStarBorder } from 'react-icons/md'; +import { + useGetImageDTOQuery, + useStarImagesMutation, + useUnstarImagesMutation, +} from 'services/api/endpoints/images'; +import IAIDndImageIcon from '../../../../common/components/IAIDndImageIcon'; interface HoverableImageProps { imageName: string; @@ -21,9 +27,7 @@ const GalleryImage = (props: HoverableImageProps) => { const dispatch = useAppDispatch(); const { imageName } = props; const { currentData: imageDTO } = useGetImageDTOQuery(imageName); - const shouldShowDeleteButton = useAppSelector( - (state) => state.gallery.shouldShowDeleteButton - ); + const shift = useAppSelector((state) => state.hotkeys.shift); const { handleClick, isSelected, selection, selectionCount } = useMultiselect(imageDTO); @@ -59,6 +63,35 @@ const GalleryImage = (props: HoverableImageProps) => { } }, [imageDTO, selection, selectionCount]); + const [starImages] = useStarImagesMutation(); + const [unstarImages] = useUnstarImagesMutation(); + + const toggleStarredState = useCallback(() => { + if (imageDTO) { + if (imageDTO.starred) { + unstarImages({ imageDTOs: [imageDTO] }); + } + if (!imageDTO.starred) { + starImages({ imageDTOs: [imageDTO] }); + } + } + }, [starImages, unstarImages, imageDTO]); + + const [isHovered, setIsHovered] = useState(false); + + const handleMouseOver = useCallback(() => { + setIsHovered(true); + }, []); + + const handleMouseOut = useCallback(() => { + setIsHovered(false); + }, []); + + const starIcon = useMemo(() => { + if (imageDTO?.starred) return ; + if (!imageDTO?.starred && isHovered) return ; + }, [imageDTO?.starred, isHovered]); + if (!imageDTO) { return ; } @@ -80,16 +113,34 @@ const GalleryImage = (props: HoverableImageProps) => { draggableData={draggableData} isSelected={isSelected} minSize={0} - onClickReset={handleDelete} imageSx={{ w: 'full', h: 'full' }} isDropDisabled={true} isUploadDisabled={true} thumbnail={true} 
withHoverOverlay - resetIcon={} - resetTooltip="Delete image" - withResetIcon={shouldShowDeleteButton} // removed bc it's too easy to accidentally delete images - /> + onMouseOver={handleMouseOver} + onMouseOut={handleMouseOut} + > + <> + + + {isHovered && shift && ( + } + tooltip="Delete" + styleOverrides={{ + bottom: 2, + top: 'auto', + }} + /> + )} + + ); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImageGrid.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImageGrid.tsx index 4a56fe0e9a..bacd5c38ad 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImageGrid.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImageGrid.tsx @@ -26,7 +26,7 @@ const overlayScrollbarsConfig: UseOverlayScrollbarsParams = { options: { scrollbars: { visibility: 'auto', - autoHide: 'leave', + autoHide: 'scroll', autoHideDelay: 1300, theme: 'os-theme-dark', }, diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataJSON.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataJSON.tsx index 590d40438b..69385607de 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataJSON.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataJSON.tsx @@ -1,26 +1,40 @@ import { Box, Flex, IconButton, Tooltip } from '@chakra-ui/react'; import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; -import { useMemo } from 'react'; -import { FaCopy } from 'react-icons/fa'; +import { useCallback, useMemo } from 'react'; +import { FaCopy, FaSave } from 'react-icons/fa'; type Props = { - copyTooltip: string; + label: string; jsonObject: object; + fileName?: string; }; const ImageMetadataJSON = (props: Props) => { - const { copyTooltip, jsonObject } = props; + const { label, jsonObject, fileName } = props; const 
jsonString = useMemo( () => JSON.stringify(jsonObject, null, 2), [jsonObject] ); + const handleCopy = useCallback(() => { + navigator.clipboard.writeText(jsonString); + }, [jsonString]); + + const handleSave = useCallback(() => { + const blob = new Blob([jsonString]); + const a = document.createElement('a'); + a.href = URL.createObjectURL(blob); + a.download = `${fileName || label}.json`; + document.body.appendChild(a); + a.click(); + a.remove(); + }, [jsonString, label, fileName]); + return ( { bottom: 0, overflow: 'auto', p: 4, + fontSize: 'sm', }} > { options={{ scrollbars: { visibility: 'auto', - autoHide: 'move', + autoHide: 'scroll', autoHideDelay: 1300, theme: 'os-theme-dark', }, @@ -54,12 +69,22 @@ const ImageMetadataJSON = (props: Props) => { - + } + variant="ghost" + opacity={0.7} + onClick={handleSave} + /> + + + } variant="ghost" - onClick={() => navigator.clipboard.writeText(jsonString)} + opacity={0.7} + onClick={handleCopy} /> diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataViewer.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataViewer.tsx index e1f2a9e46a..d70aea8a8d 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataViewer.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataViewer.tsx @@ -10,7 +10,8 @@ import { Text, } from '@chakra-ui/react'; import { skipToken } from '@reduxjs/toolkit/dist/query'; -import { memo, useMemo } from 'react'; +import { IAINoContentFallback } from 'common/components/IAIImageFallback'; +import { memo } from 'react'; import { useGetImageMetadataQuery } from 'services/api/endpoints/images'; import { ImageDTO } from 'services/api/types'; import { useDebounce } from 'use-debounce'; @@ -41,48 +42,15 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => { const metadata = currentData?.metadata; const graph = 
currentData?.graph; - const tabData = useMemo(() => { - const _tabData: { label: string; data: object; copyTooltip: string }[] = []; - - if (metadata) { - _tabData.push({ - label: 'Core Metadata', - data: metadata, - copyTooltip: 'Copy Core Metadata JSON', - }); - } - - if (image) { - _tabData.push({ - label: 'Image Details', - data: image, - copyTooltip: 'Copy Image Details JSON', - }); - } - - if (graph) { - _tabData.push({ - label: 'Graph', - data: graph, - copyTooltip: 'Copy Graph JSON', - }); - } - return _tabData; - }, [metadata, graph, image]); - return ( { sx={{ display: 'flex', flexDir: 'column', w: 'full', h: 'full' }} > - {tabData.map((tab) => ( - - - {tab.label} - - - ))} + Core Metadata + Image Details + Graph - - {tabData.map((tab) => ( - - - - ))} + + + {metadata ? ( + + ) : ( + + )} + + + {image ? ( + + ) : ( + + )} + + + {graph ? ( + + ) : ( + + )} + diff --git a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts index bc7acff6f4..a4e4b02937 100644 --- a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts @@ -13,7 +13,6 @@ export const initialGalleryState: GalleryState = { galleryImageMinimumWidth: 96, selectedBoardId: 'none', galleryView: 'images', - shouldShowDeleteButton: false, boardSearchText: '', }; @@ -50,9 +49,6 @@ export const gallerySlice = createSlice({ galleryViewChanged: (state, action: PayloadAction) => { state.galleryView = action.payload; }, - shouldShowDeleteButtonChanged: (state, action: PayloadAction) => { - state.shouldShowDeleteButton = action.payload; - }, boardSearchTextChanged: (state, action: PayloadAction) => { state.boardSearchText = action.payload; }, @@ -93,7 +89,6 @@ export const { autoAddBoardIdChanged, galleryViewChanged, selectionChanged, - shouldShowDeleteButtonChanged, boardSearchTextChanged, } = gallerySlice.actions; diff --git 
a/invokeai/frontend/web/src/features/gallery/store/types.ts b/invokeai/frontend/web/src/features/gallery/store/types.ts index 6860f6ea7b..7b707dd303 100644 --- a/invokeai/frontend/web/src/features/gallery/store/types.ts +++ b/invokeai/frontend/web/src/features/gallery/store/types.ts @@ -21,6 +21,5 @@ export type GalleryState = { galleryImageMinimumWidth: number; selectedBoardId: BoardId; galleryView: GalleryView; - shouldShowDeleteButton: boolean; boardSearchText: string; }; diff --git a/invokeai/frontend/web/src/features/nodes/components/AddNodeMenu.tsx b/invokeai/frontend/web/src/features/nodes/components/AddNodeMenu.tsx index a1a1acf1f8..a816762d0f 100644 --- a/invokeai/frontend/web/src/features/nodes/components/AddNodeMenu.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/AddNodeMenu.tsx @@ -9,30 +9,40 @@ import { map } from 'lodash-es'; import { forwardRef, useCallback } from 'react'; import 'reactflow/dist/style.css'; import { AnyInvocationType } from 'services/events/types'; -import { useBuildInvocation } from '../hooks/useBuildInvocation'; +import { useBuildNodeData } from '../hooks/useBuildNodeData'; import { nodeAdded } from '../store/nodesSlice'; type NodeTemplate = { label: string; value: string; description: string; + tags: string[]; }; const selector = createSelector( [stateSelector], ({ nodes }) => { - const data: NodeTemplate[] = map(nodes.invocationTemplates, (template) => { + const data: NodeTemplate[] = map(nodes.nodeTemplates, (template) => { return { label: template.title, value: template.type, description: template.description, + tags: template.tags, }; }); data.push({ label: 'Progress Image', - value: 'progress_image', - description: 'Displays the progress image in the Node Editor', + value: 'current_image', + description: 'Displays the current image in the Node Editor', + tags: ['progress'], + }); + + data.push({ + label: 'Notes', + value: 'notes', + description: 'Add notes about your workflow', + tags: ['notes'], }); return { 
data }; @@ -44,7 +54,7 @@ const AddNodeMenu = () => { const dispatch = useAppDispatch(); const { data } = useAppSelector(selector); - const buildInvocation = useBuildInvocation(); + const buildInvocation = useBuildNodeData(); const toaster = useAppToaster(); @@ -89,11 +99,12 @@ const AddNodeMenu = () => { filter={(value, item: NodeTemplate) => item.label.toLowerCase().includes(value.toLowerCase().trim()) || item.value.toLowerCase().includes(value.toLowerCase().trim()) || - item.description.toLowerCase().includes(value.toLowerCase().trim()) + item.description.toLowerCase().includes(value.toLowerCase().trim()) || + item.tags.includes(value.toLowerCase().trim()) } onChange={handleChange} sx={{ - width: '18rem', + width: '24rem', }} /> diff --git a/invokeai/frontend/web/src/features/nodes/components/CustomConnectionLine.tsx b/invokeai/frontend/web/src/features/nodes/components/CustomConnectionLine.tsx new file mode 100644 index 0000000000..678d8e3d1d --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/CustomConnectionLine.tsx @@ -0,0 +1,61 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { ConnectionLineComponentProps, getBezierPath } from 'reactflow'; +import { FIELDS, colorTokenToCssVar } from '../types/constants'; + +const selector = createSelector(stateSelector, ({ nodes }) => { + const { shouldAnimateEdges, currentConnectionFieldType, shouldColorEdges } = + nodes; + + const stroke = + currentConnectionFieldType && shouldColorEdges + ? 
colorTokenToCssVar(FIELDS[currentConnectionFieldType].color) + : colorTokenToCssVar('base.500'); + + let className = 'react-flow__custom_connection-path'; + + if (shouldAnimateEdges) { + className = className.concat(' animated'); + } + + return { + stroke, + className, + }; +}); + +export const CustomConnectionLine = ({ + fromX, + fromY, + fromPosition, + toX, + toY, + toPosition, +}: ConnectionLineComponentProps) => { + const { stroke, className } = useAppSelector(selector); + + const pathParams = { + sourceX: fromX, + sourceY: fromY, + sourcePosition: fromPosition, + targetX: toX, + targetY: toY, + targetPosition: toPosition, + }; + + const [dAttr] = getBezierPath(pathParams); + + return ( + + + + ); +}; diff --git a/invokeai/frontend/web/src/features/nodes/components/CustomEdges.tsx b/invokeai/frontend/web/src/features/nodes/components/CustomEdges.tsx new file mode 100644 index 0000000000..e0ccc6e323 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/CustomEdges.tsx @@ -0,0 +1,183 @@ +import { Badge, Flex } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; +import { useMemo } from 'react'; +import { + BaseEdge, + EdgeLabelRenderer, + EdgeProps, + getBezierPath, +} from 'reactflow'; +import { FIELDS, colorTokenToCssVar } from '../types/constants'; +import { isInvocationNode } from '../types/types'; + +const makeEdgeSelector = ( + source: string, + sourceHandleId: string | null | undefined, + target: string, + targetHandleId: string | null | undefined, + selected?: boolean +) => + createSelector(stateSelector, ({ nodes }) => { + const sourceNode = nodes.nodes.find((node) => node.id === source); + const targetNode = nodes.nodes.find((node) => node.id === target); + + const isInvocationToInvocationEdge = + isInvocationNode(sourceNode) && 
isInvocationNode(targetNode); + + const isSelected = sourceNode?.selected || targetNode?.selected || selected; + const sourceType = isInvocationToInvocationEdge + ? sourceNode?.data?.outputs[sourceHandleId || '']?.type + : undefined; + + const stroke = + sourceType && nodes.shouldColorEdges + ? colorTokenToCssVar(FIELDS[sourceType].color) + : colorTokenToCssVar('base.500'); + + return { + isSelected, + shouldAnimate: nodes.shouldAnimateEdges && isSelected, + stroke, + }; + }); + +const CollapsedEdge = ({ + sourceX, + sourceY, + targetX, + targetY, + sourcePosition, + targetPosition, + markerEnd, + data, + selected, + source, + target, + sourceHandleId, + targetHandleId, +}: EdgeProps<{ count: number }>) => { + const selector = useMemo( + () => + makeEdgeSelector( + source, + sourceHandleId, + target, + targetHandleId, + selected + ), + [selected, source, sourceHandleId, target, targetHandleId] + ); + + const { isSelected, shouldAnimate } = useAppSelector(selector); + + const [edgePath, labelX, labelY] = getBezierPath({ + sourceX, + sourceY, + sourcePosition, + targetX, + targetY, + targetPosition, + }); + + const { base500 } = useChakraThemeTokens(); + + return ( + <> + + {data?.count && data.count > 1 && ( + + + + {data.count} + + + + )} + + ); +}; + +const DefaultEdge = ({ + sourceX, + sourceY, + targetX, + targetY, + sourcePosition, + targetPosition, + markerEnd, + selected, + source, + target, + sourceHandleId, + targetHandleId, +}: EdgeProps) => { + const selector = useMemo( + () => + makeEdgeSelector( + source, + sourceHandleId, + target, + targetHandleId, + selected + ), + [source, sourceHandleId, target, targetHandleId, selected] + ); + + const { isSelected, shouldAnimate, stroke } = useAppSelector(selector); + + const [edgePath] = getBezierPath({ + sourceX, + sourceY, + sourcePosition, + targetX, + targetY, + targetPosition, + }); + + return ( + + ); +}; + +export const edgeTypes = { + collapsed: CollapsedEdge, + default: DefaultEdge, +}; diff --git 
a/invokeai/frontend/web/src/features/nodes/components/CustomNodes.tsx b/invokeai/frontend/web/src/features/nodes/components/CustomNodes.tsx new file mode 100644 index 0000000000..be845df435 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/CustomNodes.tsx @@ -0,0 +1,9 @@ +import CurrentImageNode from './nodes/CurrentImageNode'; +import InvocationNodeWrapper from './nodes/InvocationNodeWrapper'; +import NotesNode from './nodes/NotesNode'; + +export const nodeTypes = { + invocation: InvocationNodeWrapper, + current_image: CurrentImageNode, + notes: NotesNode, +}; diff --git a/invokeai/frontend/web/src/features/nodes/components/FieldHandle.tsx b/invokeai/frontend/web/src/features/nodes/components/FieldHandle.tsx deleted file mode 100644 index 86099a7315..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/FieldHandle.tsx +++ /dev/null @@ -1,64 +0,0 @@ -import { Tooltip } from '@chakra-ui/react'; -import { CSSProperties, memo } from 'react'; -import { Handle, Position, Connection, HandleType } from 'reactflow'; -import { FIELDS, HANDLE_TOOLTIP_OPEN_DELAY } from '../types/constants'; -// import { useConnectionEventStyles } from '../hooks/useConnectionEventStyles'; -import { InputFieldTemplate, OutputFieldTemplate } from '../types/types'; - -const handleBaseStyles: CSSProperties = { - position: 'absolute', - width: '1rem', - height: '1rem', - borderWidth: 0, -}; - -const inputHandleStyles: CSSProperties = { - left: '-1rem', -}; - -const outputHandleStyles: CSSProperties = { - right: '-0.5rem', -}; - -// const requiredConnectionStyles: CSSProperties = { -// boxShadow: '0 0 0.5rem 0.5rem var(--invokeai-colors-error-400)', -// }; - -type FieldHandleProps = { - nodeId: string; - field: InputFieldTemplate | OutputFieldTemplate; - isValidConnection: (connection: Connection) => boolean; - handleType: HandleType; - styles?: CSSProperties; -}; - -const FieldHandle = (props: FieldHandleProps) => { - const { field, isValidConnection, 
handleType, styles } = props; - const { name, type } = field; - - return ( - - - - ); -}; - -export default memo(FieldHandle); diff --git a/invokeai/frontend/web/src/features/nodes/components/FieldTypeLegend.tsx b/invokeai/frontend/web/src/features/nodes/components/FieldTypeLegend.tsx index 78316cc694..a523cc29fe 100644 --- a/invokeai/frontend/web/src/features/nodes/components/FieldTypeLegend.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/FieldTypeLegend.tsx @@ -1,8 +1,8 @@ -import 'reactflow/dist/style.css'; -import { Tooltip, Badge, Flex } from '@chakra-ui/react'; +import { Badge, Flex, Tooltip } from '@chakra-ui/react'; import { map } from 'lodash-es'; -import { FIELDS } from '../types/constants'; import { memo } from 'react'; +import 'reactflow/dist/style.css'; +import { FIELDS } from '../types/constants'; const FieldTypeLegend = () => { return ( @@ -10,8 +10,14 @@ const FieldTypeLegend = () => { {map(FIELDS, ({ title, description, color }, key) => ( {title} diff --git a/invokeai/frontend/web/src/features/nodes/components/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/Flow.tsx index 7b0718182b..8234a6a7fa 100644 --- a/invokeai/frontend/web/src/features/nodes/components/Flow.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/Flow.tsx @@ -1,4 +1,4 @@ -import { RootState } from 'app/store/store'; +import { useToken } from '@chakra-ui/react'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useCallback } from 'react'; import { @@ -7,35 +7,51 @@ import { OnConnectEnd, OnConnectStart, OnEdgesChange, - OnInit, + OnEdgesDelete, + OnMoveEnd, OnNodesChange, + OnNodesDelete, + OnSelectionChangeFunc, + ProOptions, ReactFlow, } from 'reactflow'; +import { useIsValidConnection } from '../hooks/useIsValidConnection'; import { connectionEnded, connectionMade, connectionStarted, edgesChanged, + edgesDeleted, nodesChanged, - setEditorInstance, + nodesDeleted, + selectedEdgesChanged, + selectedNodesChanged, 
+ viewportChanged, } from '../store/nodesSlice'; -import { InvocationComponent } from './InvocationComponent'; -import ProgressImageNode from './ProgressImageNode'; -import BottomLeftPanel from './panels/BottomLeftPanel.tsx'; -import MinimapPanel from './panels/MinimapPanel'; -import TopCenterPanel from './panels/TopCenterPanel'; -import TopLeftPanel from './panels/TopLeftPanel'; -import TopRightPanel from './panels/TopRightPanel'; +import { CustomConnectionLine } from './CustomConnectionLine'; +import { edgeTypes } from './CustomEdges'; +import { nodeTypes } from './CustomNodes'; +import BottomLeftPanel from './editorPanels/BottomLeftPanel'; +import MinimapPanel from './editorPanels/MinimapPanel'; +import TopCenterPanel from './editorPanels/TopCenterPanel'; +import TopLeftPanel from './editorPanels/TopLeftPanel'; +import TopRightPanel from './editorPanels/TopRightPanel'; -const nodeTypes = { - invocation: InvocationComponent, - progress_image: ProgressImageNode, -}; +// TODO: can we support reactflow? 
if not, we could style the attribution so it matches the app +const proOptions: ProOptions = { hideAttribution: true }; export const Flow = () => { const dispatch = useAppDispatch(); - const nodes = useAppSelector((state: RootState) => state.nodes.nodes); - const edges = useAppSelector((state: RootState) => state.nodes.edges); + const nodes = useAppSelector((state) => state.nodes.nodes); + const edges = useAppSelector((state) => state.nodes.edges); + const viewport = useAppSelector((state) => state.nodes.viewport); + const shouldSnapToGrid = useAppSelector( + (state) => state.nodes.shouldSnapToGrid + ); + + const isValidConnection = useIsValidConnection(); + + const [borderRadius] = useToken('radii', ['base']); const onNodesChange: OnNodesChange = useCallback( (changes) => { @@ -69,35 +85,66 @@ export const Flow = () => { dispatch(connectionEnded()); }, [dispatch]); - const onInit: OnInit = useCallback( - (v) => { - dispatch(setEditorInstance(v)); - if (v) v.fitView(); + const onEdgesDelete: OnEdgesDelete = useCallback( + (edges) => { + dispatch(edgesDeleted(edges)); + }, + [dispatch] + ); + + const onNodesDelete: OnNodesDelete = useCallback( + (nodes) => { + dispatch(nodesDeleted(nodes)); + }, + [dispatch] + ); + + const handleSelectionChange: OnSelectionChangeFunc = useCallback( + ({ nodes, edges }) => { + dispatch(selectedNodesChanged(nodes ? nodes.map((n) => n.id) : [])); + dispatch(selectedEdgesChanged(edges ? 
edges.map((e) => e.id) : [])); + }, + [dispatch] + ); + + const handleMoveEnd: OnMoveEnd = useCallback( + (e, viewport) => { + dispatch(viewportChanged(viewport)); }, [dispatch] ); return ( - + ); }; diff --git a/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeHeader.tsx b/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeHeader.tsx deleted file mode 100644 index 7b56bc95b4..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeHeader.tsx +++ /dev/null @@ -1,55 +0,0 @@ -import { Flex, Heading, Icon, Tooltip } from '@chakra-ui/react'; -import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/hooks/useBuildInvocation'; -import { memo } from 'react'; -import { FaInfoCircle } from 'react-icons/fa'; - -interface IAINodeHeaderProps { - nodeId?: string; - title?: string; - description?: string; -} - -const IAINodeHeader = (props: IAINodeHeaderProps) => { - const { nodeId, title, description } = props; - return ( - - - - {title} - - - - - - - ); -}; - -export default memo(IAINodeHeader); diff --git a/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeInputs.tsx b/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeInputs.tsx deleted file mode 100644 index 6f779e4295..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeInputs.tsx +++ /dev/null @@ -1,149 +0,0 @@ -import { - Box, - Divider, - Flex, - FormControl, - FormLabel, - HStack, - Tooltip, -} from '@chakra-ui/react'; -import { RootState } from 'app/store/store'; -import { useAppSelector } from 'app/store/storeHooks'; -import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection'; -import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants'; -import { - InputFieldTemplate, - InputFieldValue, - InvocationTemplate, -} from 'features/nodes/types/types'; -import { map } from 'lodash-es'; -import { ReactNode, memo, useCallback } from 'react'; -import FieldHandle from 
'../FieldHandle'; -import InputFieldComponent from '../InputFieldComponent'; - -interface IAINodeInputProps { - nodeId: string; - - input: InputFieldValue; - template?: InputFieldTemplate | undefined; - connected: boolean; -} - -function IAINodeInput(props: IAINodeInputProps) { - const { nodeId, input, template, connected } = props; - const isValidConnection = useIsValidConnection(); - - return ( - - - {!template ? ( - - Unknown input: {input.name} - - ) : ( - <> - - - - {template?.title} - - - - - - {!['never', 'directOnly'].includes( - template?.inputRequirement ?? '' - ) && ( - - )} - - )} - - - ); -} - -interface IAINodeInputsProps { - nodeId: string; - template: InvocationTemplate; - inputs: Record; -} - -const IAINodeInputs = (props: IAINodeInputsProps) => { - const { nodeId, template, inputs } = props; - - const edges = useAppSelector((state: RootState) => state.nodes.edges); - - const renderIAINodeInputs = useCallback(() => { - const IAINodeInputsToRender: ReactNode[] = []; - const inputSockets = map(inputs); - - inputSockets.forEach((inputSocket, index) => { - const inputTemplate = template.inputs[inputSocket.name]; - - const isConnected = Boolean( - edges.filter((connectedInput) => { - return ( - connectedInput.target === nodeId && - connectedInput.targetHandle === inputSocket.name - ); - }).length - ); - - if (index < inputSockets.length) { - IAINodeInputsToRender.push( - - ); - } - - IAINodeInputsToRender.push( - - ); - }); - - return ( - - {IAINodeInputsToRender} - - ); - }, [edges, inputs, nodeId, template.inputs]); - - return renderIAINodeInputs(); -}; - -export default memo(IAINodeInputs); diff --git a/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeOutputs.tsx b/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeOutputs.tsx deleted file mode 100644 index 2cb0bcde8d..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeOutputs.tsx +++ /dev/null @@ -1,97 +0,0 @@ -import { - 
InvocationTemplate, - OutputFieldTemplate, - OutputFieldValue, -} from 'features/nodes/types/types'; -import { memo, ReactNode, useCallback } from 'react'; -import { map } from 'lodash-es'; -import { useAppSelector } from 'app/store/storeHooks'; -import { RootState } from 'app/store/store'; -import { Box, Flex, FormControl, FormLabel, HStack } from '@chakra-ui/react'; -import FieldHandle from '../FieldHandle'; -import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection'; - -interface IAINodeOutputProps { - nodeId: string; - output: OutputFieldValue; - template?: OutputFieldTemplate | undefined; - connected: boolean; -} - -function IAINodeOutput(props: IAINodeOutputProps) { - const { nodeId, output, template, connected } = props; - const isValidConnection = useIsValidConnection(); - - return ( - - - {!template ? ( - - - Unknown Output: {output.name} - - - ) : ( - <> - - {template?.title} - - - - )} - - - ); -} - -interface IAINodeOutputsProps { - nodeId: string; - template: InvocationTemplate; - outputs: Record; -} - -const IAINodeOutputs = (props: IAINodeOutputsProps) => { - const { nodeId, template, outputs } = props; - - const edges = useAppSelector((state: RootState) => state.nodes.edges); - - const renderIAINodeOutputs = useCallback(() => { - const IAINodeOutputsToRender: ReactNode[] = []; - const outputSockets = map(outputs); - - outputSockets.forEach((outputSocket) => { - const outputTemplate = template.outputs[outputSocket.name]; - - const isConnected = Boolean( - edges.filter((connectedInput) => { - return ( - connectedInput.source === nodeId && - connectedInput.sourceHandle === outputSocket.name - ); - }).length - ); - - IAINodeOutputsToRender.push( - - ); - }); - - return {IAINodeOutputsToRender}; - }, [edges, nodeId, outputs, template.outputs]); - - return renderIAINodeOutputs(); -}; - -export default memo(IAINodeOutputs); diff --git a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx 
b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx deleted file mode 100644 index 0ecc43ef9c..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx +++ /dev/null @@ -1,252 +0,0 @@ -import { Box } from '@chakra-ui/react'; -import { memo } from 'react'; -import { InputFieldTemplate, InputFieldValue } from '../types/types'; -import ArrayInputFieldComponent from './fields/ArrayInputFieldComponent'; -import BooleanInputFieldComponent from './fields/BooleanInputFieldComponent'; -import ClipInputFieldComponent from './fields/ClipInputFieldComponent'; -import ColorInputFieldComponent from './fields/ColorInputFieldComponent'; -import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent'; -import ControlInputFieldComponent from './fields/ControlInputFieldComponent'; -import ControlNetModelInputFieldComponent from './fields/ControlNetModelInputFieldComponent'; -import EnumInputFieldComponent from './fields/EnumInputFieldComponent'; -import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent'; -import ImageInputFieldComponent from './fields/ImageInputFieldComponent'; -import ItemInputFieldComponent from './fields/ItemInputFieldComponent'; -import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent'; -import LoRAModelInputFieldComponent from './fields/LoRAModelInputFieldComponent'; -import ModelInputFieldComponent from './fields/ModelInputFieldComponent'; -import NumberInputFieldComponent from './fields/NumberInputFieldComponent'; -import StringInputFieldComponent from './fields/StringInputFieldComponent'; -import UnetInputFieldComponent from './fields/UnetInputFieldComponent'; -import VaeInputFieldComponent from './fields/VaeInputFieldComponent'; -import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent'; -import RefinerModelInputFieldComponent from './fields/RefinerModelInputFieldComponent'; - -type 
InputFieldComponentProps = { - nodeId: string; - field: InputFieldValue; - template: InputFieldTemplate; -}; - -// build an individual input element based on the schema -const InputFieldComponent = (props: InputFieldComponentProps) => { - const { nodeId, field, template } = props; - const { type } = field; - - if (type === 'string' && template.type === 'string') { - return ( - - ); - } - - if (type === 'boolean' && template.type === 'boolean') { - return ( - - ); - } - - if ( - (type === 'integer' && template.type === 'integer') || - (type === 'float' && template.type === 'float') - ) { - return ( - - ); - } - - if (type === 'enum' && template.type === 'enum') { - return ( - - ); - } - - if (type === 'image' && template.type === 'image') { - return ( - - ); - } - - if (type === 'latents' && template.type === 'latents') { - return ( - - ); - } - - if (type === 'conditioning' && template.type === 'conditioning') { - return ( - - ); - } - - if (type === 'unet' && template.type === 'unet') { - return ( - - ); - } - - if (type === 'clip' && template.type === 'clip') { - return ( - - ); - } - - if (type === 'vae' && template.type === 'vae') { - return ( - - ); - } - - if (type === 'control' && template.type === 'control') { - return ( - - ); - } - - if (type === 'model' && template.type === 'model') { - return ( - - ); - } - - if (type === 'refiner_model' && template.type === 'refiner_model') { - return ( - - ); - } - - if (type === 'vae_model' && template.type === 'vae_model') { - return ( - - ); - } - - if (type === 'lora_model' && template.type === 'lora_model') { - return ( - - ); - } - - if (type === 'controlnet_model' && template.type === 'controlnet_model') { - return ( - - ); - } - - if (type === 'array' && template.type === 'array') { - return ( - - ); - } - - if (type === 'item' && template.type === 'item') { - return ( - - ); - } - - if (type === 'color' && template.type === 'color') { - return ( - - ); - } - - if (type === 'item' && template.type === 'item') 
{ - return ( - - ); - } - - if (type === 'image_collection' && template.type === 'image_collection') { - return ( - - ); - } - - return Unknown field type: {type}; -}; - -export default memo(InputFieldComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/InvocationNode.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/InvocationNode.tsx new file mode 100644 index 0000000000..a86b52060b --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/InvocationNode.tsx @@ -0,0 +1,84 @@ +import { Flex } from '@chakra-ui/react'; +import { + InvocationNodeData, + InvocationTemplate, +} from 'features/nodes/types/types'; +import { map, some } from 'lodash-es'; +import { memo, useMemo } from 'react'; +import { NodeProps } from 'reactflow'; +import InputField from '../fields/InputField'; +import OutputField from '../fields/OutputField'; +import NodeFooter, { FOOTER_FIELDS } from './NodeFooter'; +import NodeHeader from './NodeHeader'; +import NodeWrapper from './NodeWrapper'; + +type Props = { + nodeProps: NodeProps; + nodeTemplate: InvocationTemplate; +}; + +const InvocationNode = ({ nodeProps, nodeTemplate }: Props) => { + const { id: nodeId, data } = nodeProps; + const { inputs, outputs, isOpen } = data; + + const inputFields = useMemo( + () => map(inputs).filter((i) => i.name !== 'is_intermediate'), + [inputs] + ); + const outputFields = useMemo(() => map(outputs), [outputs]); + + const withFooter = useMemo( + () => some(outputs, (output) => FOOTER_FIELDS.includes(output.type)), + [outputs] + ); + + return ( + + + {isOpen && ( + <> + + + {outputFields.map((field) => ( + + ))} + {inputFields.map((field) => ( + + ))} + + + {withFooter && ( + + )} + + )} + + ); +}; + +export default memo(InvocationNode); diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeCollapseButton.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeCollapseButton.tsx new file mode 
100644 index 0000000000..d67ca10dcc --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeCollapseButton.tsx @@ -0,0 +1,57 @@ +import { ChevronUpIcon } from '@chakra-ui/icons'; +import { useAppDispatch } from 'app/store/storeHooks'; +import IAIIconButton from 'common/components/IAIIconButton'; +import { nodeIsOpenChanged } from 'features/nodes/store/nodesSlice'; +import { NodeData } from 'features/nodes/types/types'; +import { memo, useCallback } from 'react'; +import { NodeProps, useUpdateNodeInternals } from 'reactflow'; + +interface Props { + nodeProps: NodeProps; +} + +const NodeCollapseButton = (props: Props) => { + const { id: nodeId, isOpen } = props.nodeProps.data; + const dispatch = useAppDispatch(); + const updateNodeInternals = useUpdateNodeInternals(); + + const handleClick = useCallback(() => { + dispatch(nodeIsOpenChanged({ nodeId, isOpen: !isOpen })); + updateNodeInternals(nodeId); + }, [dispatch, isOpen, nodeId, updateNodeInternals]); + + return ( + + } + /> + ); +}; + +export default memo(NodeCollapseButton); diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeCollapsedHandles.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeCollapsedHandles.tsx new file mode 100644 index 0000000000..ece24f6f8c --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeCollapsedHandles.tsx @@ -0,0 +1,74 @@ +import { useColorModeValue } from '@chakra-ui/react'; +import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; +import { + InvocationNodeData, + InvocationTemplate, +} from 'features/nodes/types/types'; +import { map } from 'lodash-es'; +import { CSSProperties, memo, useMemo } from 'react'; +import { Handle, NodeProps, Position } from 'reactflow'; + +interface Props { + nodeProps: NodeProps; + nodeTemplate: InvocationTemplate; +} + +const NodeCollapsedHandles = (props: Props) => { + const { data } = props.nodeProps; + const { 
base400, base600 } = useChakraThemeTokens(); + const backgroundColor = useColorModeValue(base400, base600); + + const dummyHandleStyles: CSSProperties = useMemo( + () => ({ + borderWidth: 0, + borderRadius: '3px', + width: '1rem', + height: '1rem', + backgroundColor, + zIndex: -1, + }), + [backgroundColor] + ); + + return ( + <> + + {map(data.inputs, (input) => ( + false} + position={Position.Left} + style={{ visibility: 'hidden' }} + /> + ))} + false} + isConnectable={false} + position={Position.Right} + style={{ ...dummyHandleStyles, right: '-0.5rem' }} + /> + {map(data.outputs, (output) => ( + false} + position={Position.Right} + style={{ visibility: 'hidden' }} + /> + ))} + + ); +}; + +export default memo(NodeCollapsedHandles); diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeFooter.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeFooter.tsx new file mode 100644 index 0000000000..38c2001b99 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeFooter.tsx @@ -0,0 +1,80 @@ +import { + Checkbox, + Flex, + FormControl, + FormLabel, + Spacer, +} from '@chakra-ui/react'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice'; +import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; +import { + InvocationNodeData, + InvocationTemplate, +} from 'features/nodes/types/types'; +import { some } from 'lodash-es'; +import { ChangeEvent, memo, useCallback, useMemo } from 'react'; +import { NodeProps } from 'reactflow'; + +export const IMAGE_FIELDS = ['ImageField', 'ImageCollection']; +export const FOOTER_FIELDS = IMAGE_FIELDS; + +type Props = { + nodeProps: NodeProps; + nodeTemplate: InvocationTemplate; +}; + +const NodeFooter = (props: Props) => { + const { nodeProps, nodeTemplate } = props; + const dispatch = useAppDispatch(); + + const hasImageOutput = useMemo( + () => + 
some(nodeTemplate?.outputs, (output) => + IMAGE_FIELDS.includes(output.type) + ), + [nodeTemplate?.outputs] + ); + + const handleChangeIsIntermediate = useCallback( + (e: ChangeEvent) => { + dispatch( + fieldBooleanValueChanged({ + nodeId: nodeProps.data.id, + fieldName: 'is_intermediate', + value: !e.target.checked, + }) + ); + }, + [dispatch, nodeProps.data.id] + ); + + return ( + + + {hasImageOutput && ( + + Save Output + + + )} + + ); +}; + +export default memo(NodeFooter); diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeHeader.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeHeader.tsx new file mode 100644 index 0000000000..a946d21581 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeHeader.tsx @@ -0,0 +1,54 @@ +import { Flex } from '@chakra-ui/react'; +import { + InvocationNodeData, + InvocationTemplate, +} from 'features/nodes/types/types'; +import { memo } from 'react'; +import { NodeProps } from 'reactflow'; +import NodeCollapseButton from '../Invocation/NodeCollapseButton'; +import NodeCollapsedHandles from '../Invocation/NodeCollapsedHandles'; +import NodeNotesEdit from '../Invocation/NodeNotesEdit'; +import NodeStatusIndicator from '../Invocation/NodeStatusIndicator'; +import NodeTitle from '../Invocation/NodeTitle'; + +type Props = { + nodeProps: NodeProps; + nodeTemplate: InvocationTemplate; +}; + +const NodeHeader = (props: Props) => { + const { nodeProps, nodeTemplate } = props; + const { isOpen } = nodeProps.data; + + return ( + + + + + + + + {!isOpen && ( + + )} + + ); +}; + +export default memo(NodeHeader); diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeNotesEdit.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeNotesEdit.tsx new file mode 100644 index 0000000000..ab54ca2c44 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeNotesEdit.tsx @@ -0,0 +1,113 @@ +import 
{ + Flex, + FormControl, + FormLabel, + Icon, + Modal, + ModalBody, + ModalCloseButton, + ModalContent, + ModalFooter, + ModalHeader, + ModalOverlay, + Text, + Tooltip, + useDisclosure, +} from '@chakra-ui/react'; +import { useAppDispatch } from 'app/store/storeHooks'; +import IAITextarea from 'common/components/IAITextarea'; +import { nodeNotesChanged } from 'features/nodes/store/nodesSlice'; +import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; +import { + InvocationNodeData, + InvocationTemplate, +} from 'features/nodes/types/types'; +import { ChangeEvent, memo, useCallback } from 'react'; +import { FaInfoCircle } from 'react-icons/fa'; +import { NodeProps } from 'reactflow'; + +interface Props { + nodeProps: NodeProps; + nodeTemplate: InvocationTemplate; +} + +const NodeNotesEdit = (props: Props) => { + const { nodeProps, nodeTemplate } = props; + const { data } = nodeProps; + const { isOpen, onOpen, onClose } = useDisclosure(); + const dispatch = useAppDispatch(); + const handleNotesChanged = useCallback( + (e: ChangeEvent) => { + dispatch(nodeNotesChanged({ nodeId: data.id, notes: e.target.value })); + }, + [data.id, dispatch] + ); + + return ( + <> + + ) : undefined + } + placement="top" + shouldWrapChildren + > + + + + + + + + + + {data.label || nodeTemplate?.title || 'Unknown Node'} + + + + + Notes + + + + + + + + ); +}; + +export default memo(NodeNotesEdit); + +type TooltipContentProps = Props; + +const TooltipContent = (props: TooltipContentProps) => { + return ( + + {props.nodeTemplate?.title} + + {props.nodeTemplate?.description} + + {props.nodeProps.data.notes && {props.nodeProps.data.notes}} + + ); +}; diff --git a/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeResizer.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeResizer.tsx similarity index 73% rename from invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeResizer.tsx rename to 
invokeai/frontend/web/src/features/nodes/components/Invocation/NodeResizer.tsx index 1aca32ec70..6391e86471 100644 --- a/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeResizer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeResizer.tsx @@ -2,7 +2,10 @@ import { NODE_MIN_WIDTH } from 'features/nodes/types/constants'; import { memo } from 'react'; import { NodeResizeControl, NodeResizerProps } from 'reactflow'; -const IAINodeResizer = (props: NodeResizerProps) => { +// this causes https://github.com/invoke-ai/InvokeAI/issues/4140 +// not using it for now + +const NodeResizer = (props: NodeResizerProps) => { const { ...rest } = props; return ( { ); }; -export default memo(IAINodeResizer); +export default memo(NodeResizer); diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeSettings.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeSettings.tsx new file mode 100644 index 0000000000..bf12358871 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeSettings.tsx @@ -0,0 +1,69 @@ +import { Flex } from '@chakra-ui/react'; +import { useAppDispatch } from 'app/store/storeHooks'; +import IAIIconButton from 'common/components/IAIIconButton'; +import IAIPopover from 'common/components/IAIPopover'; +import IAISwitch from 'common/components/IAISwitch'; +import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice'; +import { InvocationNodeData } from 'features/nodes/types/types'; +import { ChangeEvent, memo, useCallback } from 'react'; +import { FaBars } from 'react-icons/fa'; + +interface Props { + data: InvocationNodeData; +} + +const NodeSettings = (props: Props) => { + const { data } = props; + const dispatch = useAppDispatch(); + + const handleChangeIsIntermediate = useCallback( + (e: ChangeEvent) => { + dispatch( + fieldBooleanValueChanged({ + nodeId: data.id, + fieldName: 'is_intermediate', + value: e.target.checked, + }) + ); + 
}, + [data.id, dispatch] + ); + + return ( + } + /> + } + > + + + + + ); +}; + +export default memo(NodeSettings); diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeStatusIndicator.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeStatusIndicator.tsx new file mode 100644 index 0000000000..6695c4fd3b --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeStatusIndicator.tsx @@ -0,0 +1,185 @@ +import { + Badge, + CircularProgress, + Flex, + Icon, + Image, + Text, + Tooltip, +} from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; +import { + InvocationNodeData, + NodeExecutionState, + NodeStatus, +} from 'features/nodes/types/types'; +import { memo, useMemo } from 'react'; +import { FaCheck, FaEllipsisH, FaExclamation } from 'react-icons/fa'; +import { NodeProps } from 'reactflow'; + +type Props = { + nodeProps: NodeProps; +}; + +const iconBoxSize = 3; +const circleStyles = { + circle: { + transitionProperty: 'none', + transitionDuration: '0s', + }, + '.chakra-progress__track': { stroke: 'transparent' }, +}; + +const NodeStatusIndicator = (props: Props) => { + const nodeId = props.nodeProps.data.id; + const selectNodeExecutionState = useMemo( + () => + createSelector( + stateSelector, + ({ nodes }) => nodes.nodeExecutionStates[nodeId] + ), + [nodeId] + ); + + const nodeExecutionState = useAppSelector(selectNodeExecutionState); + + if (!nodeExecutionState) { + return null; + } + + return ( + } + placement="top" + > + + + + + ); +}; + +export default memo(NodeStatusIndicator); + +type TooltipLabelProps = { + nodeExecutionState: NodeExecutionState; +}; + +const TooltipLabel = ({ nodeExecutionState }: TooltipLabelProps) => { + const { status, progress, progressImage } = nodeExecutionState; 
+ if (status === NodeStatus.PENDING) { + return Pending; + } + + if (status === NodeStatus.IN_PROGRESS) { + if (progressImage) { + return ( + + + {progress !== null && ( + + {Math.round(progress * 100)}% + + )} + + ); + } + + if (progress !== null) { + return In Progress ({Math.round(progress * 100)}%); + } + + return In Progress; + } + + if (status === NodeStatus.COMPLETED) { + return Completed; + } + + if (status === NodeStatus.FAILED) { + return nodeExecutionState.error; + } + + return null; +}; + +type StatusIconProps = { + nodeExecutionState: NodeExecutionState; +}; + +const StatusIcon = (props: StatusIconProps) => { + const { progress, status } = props.nodeExecutionState; + if (status === NodeStatus.PENDING) { + return ( + + ); + } + if (status === NodeStatus.IN_PROGRESS) { + return progress === null ? ( + + ) : ( + + ); + } + if (status === NodeStatus.COMPLETED) { + return ( + + ); + } + if (status === NodeStatus.FAILED) { + return ( + + ); + } + return null; +}; diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeTitle.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeTitle.tsx new file mode 100644 index 0000000000..fa6a8ea224 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeTitle.tsx @@ -0,0 +1,123 @@ +import { + Box, + Editable, + EditableInput, + EditablePreview, + Flex, + useEditableControls, +} from '@chakra-ui/react'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { nodeLabelChanged } from 'features/nodes/store/nodesSlice'; +import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; +import { NodeData } from 'features/nodes/types/types'; +import { MouseEvent, memo, useCallback, useEffect, useState } from 'react'; + +type Props = { + nodeData: NodeData; + title: string; +}; + +const NodeTitle = (props: Props) => { + const { title } = props; + const { id: nodeId, label } = props.nodeData; + const dispatch = useAppDispatch(); + const 
[localTitle, setLocalTitle] = useState(label || title); + + const handleSubmit = useCallback( + async (newTitle: string) => { + dispatch(nodeLabelChanged({ nodeId, label: newTitle })); + setLocalTitle(newTitle || title); + }, + [nodeId, dispatch, title] + ); + + const handleChange = useCallback((newTitle: string) => { + setLocalTitle(newTitle); + }, []); + + useEffect(() => { + // Another component may change the title; sync local title with global state + setLocalTitle(label || title); + }, [label, title]); + + return ( + + + + + + + + ); +}; + +export default memo(NodeTitle); + +function EditableControls() { + const { isEditing, getEditButtonProps } = useEditableControls(); + const handleDoubleClick = useCallback( + (e: MouseEvent) => { + const { onClick } = getEditButtonProps(); + if (!onClick) { + return; + } + onClick(e); + }, + [getEditButtonProps] + ); + + if (isEditing) { + return null; + } + + return ( + + ); +} diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeWrapper.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeWrapper.tsx new file mode 100644 index 0000000000..2f555d700a --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeWrapper.tsx @@ -0,0 +1,96 @@ +import { + Box, + ChakraProps, + useColorModeValue, + useToken, +} from '@chakra-ui/react'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { nodeClicked } from 'features/nodes/store/nodesSlice'; +import { MouseEvent, PropsWithChildren, useCallback, useMemo } from 'react'; +import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from '../../types/constants'; +import { NodeData } from 'features/nodes/types/types'; +import { NodeProps } from 'reactflow'; + +const useNodeSelect = (nodeId: string) => { + const dispatch = useAppDispatch(); + + const selectNode = useCallback( + (e: MouseEvent) => { + dispatch(nodeClicked({ nodeId, ctrlOrMeta: e.ctrlKey || e.metaKey })); + }, + [dispatch, nodeId] + ); 
+ + return selectNode; +}; + +type NodeWrapperProps = PropsWithChildren & { + nodeProps: NodeProps; + width?: NonNullable['w']; +}; + +const NodeWrapper = (props: NodeWrapperProps) => { + const { width, children, nodeProps } = props; + const { data, selected } = nodeProps; + const nodeId = data.id; + + const [ + nodeSelectedOutlineLight, + nodeSelectedOutlineDark, + shadowsXl, + shadowsBase, + ] = useToken('shadows', [ + 'nodeSelectedOutline.light', + 'nodeSelectedOutline.dark', + 'shadows.xl', + 'shadows.base', + ]); + + const selectNode = useNodeSelect(nodeId); + + const shadow = useColorModeValue( + nodeSelectedOutlineLight, + nodeSelectedOutlineDark + ); + + const shift = useAppSelector((state) => state.hotkeys.shift); + const opacity = useAppSelector((state) => state.nodes.nodeOpacity); + const className = useMemo( + () => (shift ? DRAG_HANDLE_CLASSNAME : 'nopan'), + [shift] + ); + + return ( + + + {children} + + ); +}; + +export default NodeWrapper; diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/UnknownNodeFallback.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/UnknownNodeFallback.tsx new file mode 100644 index 0000000000..a16c6960ec --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/UnknownNodeFallback.tsx @@ -0,0 +1,69 @@ +import { Box, Flex, Text } from '@chakra-ui/react'; +import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; +import { InvocationNodeData } from 'features/nodes/types/types'; +import { memo } from 'react'; +import { NodeProps } from 'reactflow'; +import NodeCollapseButton from '../Invocation/NodeCollapseButton'; +import NodeWrapper from '../Invocation/NodeWrapper'; + +type Props = { + nodeProps: NodeProps; +}; + +const UnknownNodeFallback = ({ nodeProps }: Props) => { + const { data } = nodeProps; + const { isOpen, label, type } = data; + return ( + + + + + {label ? 
`${label} (${type})` : type} + + + {isOpen && ( + + + Unknown node type: + + {type} + + + + )} + + ); +}; + +export default memo(UnknownNodeFallback); diff --git a/invokeai/frontend/web/src/features/nodes/components/InvocationComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/InvocationComponent.tsx deleted file mode 100644 index 4c031afaff..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/InvocationComponent.tsx +++ /dev/null @@ -1,74 +0,0 @@ -import { Flex, Icon } from '@chakra-ui/react'; -import { FaExclamationCircle } from 'react-icons/fa'; -import { NodeProps } from 'reactflow'; -import { InvocationValue } from '../types/types'; - -import { useAppSelector } from 'app/store/storeHooks'; -import { memo, useMemo } from 'react'; -import { makeTemplateSelector } from '../store/util/makeTemplateSelector'; -import IAINodeHeader from './IAINode/IAINodeHeader'; -import IAINodeInputs from './IAINode/IAINodeInputs'; -import IAINodeOutputs from './IAINode/IAINodeOutputs'; -import IAINodeResizer from './IAINode/IAINodeResizer'; -import NodeWrapper from './NodeWrapper'; - -export const InvocationComponent = memo((props: NodeProps) => { - const { id: nodeId, data, selected } = props; - const { type, inputs, outputs } = data; - - const templateSelector = useMemo(() => makeTemplateSelector(type), [type]); - - const template = useAppSelector(templateSelector); - - if (!template) { - return ( - - - - - - - ); - } - - return ( - - - - - - - - - ); -}); - -InvocationComponent.displayName = 'InvocationComponent'; diff --git a/invokeai/frontend/web/src/features/nodes/components/NodeEditor.tsx b/invokeai/frontend/web/src/features/nodes/components/NodeEditor.tsx index 8c0480774c..6920a2053b 100644 --- a/invokeai/frontend/web/src/features/nodes/components/NodeEditor.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/NodeEditor.tsx @@ -1,25 +1,109 @@ -import { Box } from '@chakra-ui/react'; -import { ReactFlowProvider } from 'reactflow'; 
+import { Flex } from '@chakra-ui/react'; +import { useAppSelector } from 'app/store/storeHooks'; +import { IAINoContentFallback } from 'common/components/IAIImageFallback'; +import ResizeHandle from 'features/ui/components/tabs/ResizeHandle'; +import { memo, useState } from 'react'; +import { MdDeviceHub } from 'react-icons/md'; +import { Panel, PanelGroup } from 'react-resizable-panels'; import 'reactflow/dist/style.css'; - -import { memo } from 'react'; +import NodeEditorPanelGroup from './panel/NodeEditorPanelGroup'; import { Flow } from './Flow'; +import { AnimatePresence, motion } from 'framer-motion'; const NodeEditor = () => { + const [isPanelCollapsed, setIsPanelCollapsed] = useState(false); + const isReady = useAppSelector((state) => state.nodes.isReady); return ( - - - - - + + + + + + + + {isReady && ( + + + + )} + + + {!isReady && ( + + + + + + )} + + + + ); }; diff --git a/invokeai/frontend/web/src/features/nodes/components/NodeEditorSettings.tsx b/invokeai/frontend/web/src/features/nodes/components/NodeEditorSettings.tsx new file mode 100644 index 0000000000..58e2e3564e --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/NodeEditorSettings.tsx @@ -0,0 +1,139 @@ +import { + Divider, + Flex, + Heading, + Modal, + ModalBody, + ModalCloseButton, + ModalContent, + ModalHeader, + ModalOverlay, + useDisclosure, +} from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import IAIIconButton from 'common/components/IAIIconButton'; +import IAISwitch from 'common/components/IAISwitch'; +import { ChangeEvent, useCallback } from 'react'; +import { FaCog } from 'react-icons/fa'; +import { + shouldAnimateEdgesChanged, + shouldColorEdgesChanged, + shouldSnapToGridChanged, + shouldValidateGraphChanged, +} from '../store/nodesSlice'; + +const selector = createSelector(stateSelector, ({ nodes }) => { + const 
{ + shouldAnimateEdges, + shouldValidateGraph, + shouldSnapToGrid, + shouldColorEdges, + } = nodes; + return { + shouldAnimateEdges, + shouldValidateGraph, + shouldSnapToGrid, + shouldColorEdges, + }; +}); + +const NodeEditorSettings = () => { + const { isOpen, onOpen, onClose } = useDisclosure(); + const dispatch = useAppDispatch(); + const { + shouldAnimateEdges, + shouldValidateGraph, + shouldSnapToGrid, + shouldColorEdges, + } = useAppSelector(selector); + + const handleChangeShouldValidate = useCallback( + (e: ChangeEvent) => { + dispatch(shouldValidateGraphChanged(e.target.checked)); + }, + [dispatch] + ); + + const handleChangeShouldAnimate = useCallback( + (e: ChangeEvent) => { + dispatch(shouldAnimateEdgesChanged(e.target.checked)); + }, + [dispatch] + ); + + const handleChangeShouldSnap = useCallback( + (e: ChangeEvent) => { + dispatch(shouldSnapToGridChanged(e.target.checked)); + }, + [dispatch] + ); + + const handleChangeShouldColor = useCallback( + (e: ChangeEvent) => { + dispatch(shouldColorEdgesChanged(e.target.checked)); + }, + [dispatch] + ); + + return ( + <> + } + onClick={onOpen} + /> + + + + + Node Editor Settings + + + + General + + + + + + + Advanced + + + + + + + + ); +}; + +export default NodeEditorSettings; diff --git a/invokeai/frontend/web/src/features/nodes/components/NodeGraphOverlay.tsx b/invokeai/frontend/web/src/features/nodes/components/NodeGraphOverlay.tsx index 1d498f19f5..4525dc5f6b 100644 --- a/invokeai/frontend/web/src/features/nodes/components/NodeGraphOverlay.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/NodeGraphOverlay.tsx @@ -1,34 +1,26 @@ -import { Box } from '@chakra-ui/react'; import { RootState } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; -import { memo } from 'react'; +import ImageMetadataJSON from 'features/gallery/components/ImageMetadataViewer/ImageMetadataJSON'; +import { omit } from 'lodash-es'; +import { useMemo } from 'react'; +import { useDebounce } from 
'use-debounce'; import { buildNodesGraph } from '../util/graphBuilders/buildNodesGraph'; -const NodeGraphOverlay = () => { - const state = useAppSelector((state: RootState) => state); - const graph = buildNodesGraph(state); - - return ( - - {JSON.stringify(graph, null, 2)} - +const useNodesGraph = () => { + const nodes = useAppSelector((state: RootState) => state.nodes); + const [debouncedNodes] = useDebounce(nodes, 300); + const graph = useMemo( + () => omit(buildNodesGraph(debouncedNodes), 'id'), + [debouncedNodes] ); + + return graph; }; -export default memo(NodeGraphOverlay); +const NodeGraph = () => { + const graph = useNodesGraph(); + + return ; +}; + +export default NodeGraph; diff --git a/invokeai/frontend/web/src/features/nodes/components/NodeOpacitySlider.tsx b/invokeai/frontend/web/src/features/nodes/components/NodeOpacitySlider.tsx new file mode 100644 index 0000000000..693940859f --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/NodeOpacitySlider.tsx @@ -0,0 +1,42 @@ +import { + Box, + Slider, + SliderFilledTrack, + SliderThumb, + SliderTrack, +} from '@chakra-ui/react'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useCallback } from 'react'; +import { nodeOpacityChanged } from '../store/nodesSlice'; + +export default function NodeOpacitySlider() { + const dispatch = useAppDispatch(); + const nodeOpacity = useAppSelector((state) => state.nodes.nodeOpacity); + + const handleChange = useCallback( + (v: number) => { + dispatch(nodeOpacityChanged(v)); + }, + [dispatch] + ); + + return ( + + + + + + + + + ); +} diff --git a/invokeai/frontend/web/src/features/nodes/components/NodeWrapper.tsx b/invokeai/frontend/web/src/features/nodes/components/NodeWrapper.tsx deleted file mode 100644 index bc7944a28b..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/NodeWrapper.tsx +++ /dev/null @@ -1,36 +0,0 @@ -import { Box, useToken } from '@chakra-ui/react'; -import { useAppSelector } from 
'app/store/storeHooks'; -import { PropsWithChildren } from 'react'; -import { DRAG_HANDLE_CLASSNAME } from '../hooks/useBuildInvocation'; -import { NODE_MIN_WIDTH } from '../types/constants'; - -type NodeWrapperProps = PropsWithChildren & { - selected: boolean; -}; - -const NodeWrapper = (props: NodeWrapperProps) => { - const [nodeSelectedOutline, nodeShadow] = useToken('shadows', [ - 'nodeSelectedOutline', - 'dark-lg', - ]); - - const shift = useAppSelector((state) => state.hotkeys.shift); - - return ( - - {props.children} - - ); -}; - -export default NodeWrapper; diff --git a/invokeai/frontend/web/src/features/nodes/components/ProgressImageNode.tsx b/invokeai/frontend/web/src/features/nodes/components/ProgressImageNode.tsx deleted file mode 100644 index 142e2a2990..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/ProgressImageNode.tsx +++ /dev/null @@ -1,73 +0,0 @@ -import { Flex, Image } from '@chakra-ui/react'; -import { RootState } from 'app/store/store'; -import { IAINoContentFallback } from 'common/components/IAIImageFallback'; -import { memo } from 'react'; -import { useDispatch, useSelector } from 'react-redux'; -import { NodeProps, OnResize } from 'reactflow'; -import { setProgressNodeSize } from '../store/nodesSlice'; -import IAINodeHeader from './IAINode/IAINodeHeader'; -import IAINodeResizer from './IAINode/IAINodeResizer'; -import NodeWrapper from './NodeWrapper'; - -const ProgressImageNode = (props: NodeProps) => { - const progressImage = useSelector( - (state: RootState) => state.system.progressImage - ); - const progressNodeSize = useSelector( - (state: RootState) => state.nodes.progressNodeSize - ); - const dispatch = useDispatch(); - const { selected } = props; - - const handleResize: OnResize = (_, newSize) => { - dispatch(setProgressNodeSize(newSize)); - }; - - return ( - - - - {progressImage ? 
( - - ) : ( - - - - )} - - - - ); -}; - -export default memo(ProgressImageNode); diff --git a/invokeai/frontend/web/src/features/nodes/components/ViewportControls.tsx b/invokeai/frontend/web/src/features/nodes/components/ViewportControls.tsx index 796cdb010e..7416c6c555 100644 --- a/invokeai/frontend/web/src/features/nodes/components/ViewportControls.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/ViewportControls.tsx @@ -2,18 +2,16 @@ import { ButtonGroup, Tooltip } from '@chakra-ui/react'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIIconButton from 'common/components/IAIIconButton'; import { memo, useCallback } from 'react'; -import { - FaCode, - FaExpand, - FaMinus, - FaPlus, - FaInfo, - FaMapMarkerAlt, -} from 'react-icons/fa'; -import { useReactFlow } from 'reactflow'; import { useTranslation } from 'react-i18next'; import { - shouldShowGraphOverlayChanged, + FaExpand, + FaInfo, + FaMapMarkerAlt, + FaMinus, + FaPlus, +} from 'react-icons/fa'; +import { useReactFlow } from 'reactflow'; +import { shouldShowFieldTypeLegendChanged, shouldShowMinimapPanelChanged, } from '../store/nodesSlice'; @@ -22,9 +20,6 @@ const ViewportControls = () => { const { t } = useTranslation(); const { zoomIn, zoomOut, fitView } = useReactFlow(); const dispatch = useAppDispatch(); - const shouldShowGraphOverlay = useAppSelector( - (state) => state.nodes.shouldShowGraphOverlay - ); const shouldShowFieldTypeLegend = useAppSelector( (state) => state.nodes.shouldShowFieldTypeLegend ); @@ -44,10 +39,6 @@ const ViewportControls = () => { fitView(); }, [fitView]); - const handleClickedToggleGraphOverlay = useCallback(() => { - dispatch(shouldShowGraphOverlayChanged(!shouldShowGraphOverlay)); - }, [shouldShowGraphOverlay, dispatch]); - const handleClickedToggleFieldTypeLegend = useCallback(() => { dispatch(shouldShowFieldTypeLegendChanged(!shouldShowFieldTypeLegend)); }, [shouldShowFieldTypeLegend, dispatch]); @@ -79,20 +70,6 @@ const 
ViewportControls = () => { icon={} /> - - } - /> - ( - + + + + ); diff --git a/invokeai/frontend/web/src/features/nodes/components/panels/MinimapPanel.tsx b/invokeai/frontend/web/src/features/nodes/components/editorPanels/MinimapPanel.tsx similarity index 91% rename from invokeai/frontend/web/src/features/nodes/components/panels/MinimapPanel.tsx rename to invokeai/frontend/web/src/features/nodes/components/editorPanels/MinimapPanel.tsx index 39142ed48e..8b7fb942a6 100644 --- a/invokeai/frontend/web/src/features/nodes/components/panels/MinimapPanel.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/editorPanels/MinimapPanel.tsx @@ -20,7 +20,7 @@ const MinimapPanel = () => { const nodeColor = useColorModeValue( 'var(--invokeai-colors-accent-300)', - 'var(--invokeai-colors-accent-700)' + 'var(--invokeai-colors-accent-600)' ); const maskColor = useColorModeValue( @@ -32,10 +32,9 @@ const MinimapPanel = () => { <> {shouldShowMinimapPanel && ( { return ( @@ -15,9 +14,8 @@ const TopCenterPanel = () => { - - + ); diff --git a/invokeai/frontend/web/src/features/nodes/components/panels/TopLeftPanel.tsx b/invokeai/frontend/web/src/features/nodes/components/editorPanels/TopLeftPanel.tsx similarity index 100% rename from invokeai/frontend/web/src/features/nodes/components/panels/TopLeftPanel.tsx rename to invokeai/frontend/web/src/features/nodes/components/editorPanels/TopLeftPanel.tsx diff --git a/invokeai/frontend/web/src/features/nodes/components/panels/TopRightPanel.tsx b/invokeai/frontend/web/src/features/nodes/components/editorPanels/TopRightPanel.tsx similarity index 55% rename from invokeai/frontend/web/src/features/nodes/components/panels/TopRightPanel.tsx rename to invokeai/frontend/web/src/features/nodes/components/editorPanels/TopRightPanel.tsx index e3e3a871c8..7facf3973f 100644 --- a/invokeai/frontend/web/src/features/nodes/components/panels/TopRightPanel.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/editorPanels/TopRightPanel.tsx @@ 
-1,22 +1,16 @@ -import { RootState } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { memo } from 'react'; import { Panel } from 'reactflow'; import FieldTypeLegend from '../FieldTypeLegend'; -import NodeGraphOverlay from '../NodeGraphOverlay'; const TopRightPanel = () => { - const shouldShowGraphOverlay = useAppSelector( - (state: RootState) => state.nodes.shouldShowGraphOverlay - ); const shouldShowFieldTypeLegend = useAppSelector( - (state: RootState) => state.nodes.shouldShowFieldTypeLegend + (state) => state.nodes.shouldShowFieldTypeLegend ); return ( {shouldShowFieldTypeLegend && } - {shouldShowGraphOverlay && } ); }; diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ArrayInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ArrayInputFieldComponent.tsx deleted file mode 100644 index 8e478c907c..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/fields/ArrayInputFieldComponent.tsx +++ /dev/null @@ -1,15 +0,0 @@ -import { - ArrayInputFieldTemplate, - ArrayInputFieldValue, -} from 'features/nodes/types/types'; -import { memo } from 'react'; -import { FaList } from 'react-icons/fa'; -import { FieldComponentProps } from './types'; - -const ArrayInputFieldComponent = ( - _props: FieldComponentProps -) => { - return ; -}; - -export default memo(ArrayInputFieldComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/EnumInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/EnumInputFieldComponent.tsx deleted file mode 100644 index 5f26bc4f2a..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/fields/EnumInputFieldComponent.tsx +++ /dev/null @@ -1,37 +0,0 @@ -import { Select } from '@chakra-ui/react'; -import { useAppDispatch } from 'app/store/storeHooks'; -import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; -import { - EnumInputFieldTemplate, - EnumInputFieldValue, 
-} from 'features/nodes/types/types'; -import { ChangeEvent, memo } from 'react'; -import { FieldComponentProps } from './types'; - -const EnumInputFieldComponent = ( - props: FieldComponentProps -) => { - const { nodeId, field, template } = props; - - const dispatch = useAppDispatch(); - - const handleValueChanged = (e: ChangeEvent) => { - dispatch( - fieldValueChanged({ - nodeId, - fieldName: field.name, - value: e.target.value, - }) - ); - }; - - return ( - - ); -}; - -export default memo(EnumInputFieldComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/FieldContextMenu.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/FieldContextMenu.tsx new file mode 100644 index 0000000000..d9f8f951bc --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/fields/FieldContextMenu.tsx @@ -0,0 +1,47 @@ +import { MenuItem, MenuList } from '@chakra-ui/react'; +import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu'; +import { + InputFieldTemplate, + InputFieldValue, +} from 'features/nodes/types/types'; +import { MouseEvent, useCallback } from 'react'; +import { menuListMotionProps } from 'theme/components/menu'; + +type Props = { + nodeId: string; + field: InputFieldValue; + fieldTemplate: InputFieldTemplate; + children: ContextMenuProps['children']; +}; + +const FieldContextMenu = (props: Props) => { + const skipEvent = useCallback((e: MouseEvent) => { + e.preventDefault(); + }, []); + + return ( + + menuProps={{ + size: 'sm', + isLazy: true, + }} + menuButtonProps={{ + bg: 'transparent', + _hover: { bg: 'transparent' }, + }} + renderMenu={() => ( + + Test + + )} + > + {props.children} + + ); +}; + +export default FieldContextMenu; diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/FieldHandle.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/FieldHandle.tsx new file mode 100644 index 0000000000..f47e68976d --- /dev/null +++ 
b/invokeai/frontend/web/src/features/nodes/components/fields/FieldHandle.tsx @@ -0,0 +1,122 @@ +import { Tooltip } from '@chakra-ui/react'; +import { CSSProperties, memo, useMemo } from 'react'; +import { Handle, HandleType, NodeProps, Position } from 'reactflow'; +import { + FIELDS, + HANDLE_TOOLTIP_OPEN_DELAY, + colorTokenToCssVar, +} from '../../types/constants'; +import { + InputFieldTemplate, + InputFieldValue, + InvocationNodeData, + InvocationTemplate, + OutputFieldTemplate, + OutputFieldValue, +} from '../../types/types'; + +export const handleBaseStyles: CSSProperties = { + position: 'absolute', + width: '1rem', + height: '1rem', + borderWidth: 0, + zIndex: 1, +}; + +export const inputHandleStyles: CSSProperties = { + left: '-1rem', +}; + +export const outputHandleStyles: CSSProperties = { + right: '-0.5rem', +}; + +type FieldHandleProps = { + nodeProps: NodeProps; + nodeTemplate: InvocationTemplate; + field: InputFieldValue | OutputFieldValue; + fieldTemplate: InputFieldTemplate | OutputFieldTemplate; + handleType: HandleType; + isConnectionInProgress: boolean; + isConnectionStartField: boolean; + connectionError: string | null; +}; + +const FieldHandle = (props: FieldHandleProps) => { + const { + fieldTemplate, + handleType, + isConnectionInProgress, + isConnectionStartField, + connectionError, + } = props; + const { name, type } = fieldTemplate; + const { color, title } = FIELDS[type]; + + const styles: CSSProperties = useMemo(() => { + const s: CSSProperties = { + backgroundColor: colorTokenToCssVar(color), + position: 'absolute', + width: '1rem', + height: '1rem', + borderWidth: 0, + zIndex: 1, + }; + + if (handleType === 'target') { + s.insetInlineStart = '-1rem'; + } else { + s.insetInlineEnd = '-1rem'; + } + + if (isConnectionInProgress && !isConnectionStartField && connectionError) { + s.filter = 'opacity(0.4) grayscale(0.7)'; + } + + if (isConnectionInProgress && connectionError) { + if (isConnectionStartField) { + s.cursor = 'grab'; + } else { + 
s.cursor = 'not-allowed'; + } + } else { + s.cursor = 'crosshair'; + } + + return s; + }, [ + color, + connectionError, + handleType, + isConnectionInProgress, + isConnectionStartField, + ]); + + const tooltip = useMemo(() => { + if (isConnectionInProgress && isConnectionStartField) { + return title; + } + if (isConnectionInProgress && connectionError) { + return connectionError ?? title; + } + return title; + }, [connectionError, isConnectionInProgress, isConnectionStartField, title]); + + return ( + + + + ); +}; + +export default memo(FieldHandle); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/FieldTitle.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/FieldTitle.tsx new file mode 100644 index 0000000000..fc239addf3 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/fields/FieldTitle.tsx @@ -0,0 +1,161 @@ +import { + Editable, + EditableInput, + EditablePreview, + Flex, + useEditableControls, +} from '@chakra-ui/react'; +import { useAppDispatch } from 'app/store/storeHooks'; +import IAIDraggable from 'common/components/IAIDraggable'; +import { NodeFieldDraggableData } from 'features/dnd/types'; +import { fieldLabelChanged } from 'features/nodes/store/nodesSlice'; +import { + InputFieldTemplate, + InputFieldValue, + InvocationNodeData, + InvocationTemplate, +} from 'features/nodes/types/types'; +import { + MouseEvent, + memo, + useCallback, + useEffect, + useMemo, + useState, +} from 'react'; + +interface Props { + nodeData: InvocationNodeData; + nodeTemplate: InvocationTemplate; + field: InputFieldValue; + fieldTemplate: InputFieldTemplate; + isDraggable?: boolean; +} + +const FieldTitle = (props: Props) => { + const { nodeData, field, fieldTemplate, isDraggable = false } = props; + const { label } = field; + const { title, input } = fieldTemplate; + const { id: nodeId } = nodeData; + const dispatch = useAppDispatch(); + const [localTitle, setLocalTitle] = useState(label || title); + + const 
draggableData: NodeFieldDraggableData | undefined = useMemo( + () => + input !== 'connection' && isDraggable + ? { + id: `${nodeId}-${field.name}`, + payloadType: 'NODE_FIELD', + payload: { nodeId, field, fieldTemplate }, + } + : undefined, + [field, fieldTemplate, input, isDraggable, nodeId] + ); + + const handleSubmit = useCallback( + async (newTitle: string) => { + dispatch( + fieldLabelChanged({ nodeId, fieldName: field.name, label: newTitle }) + ); + setLocalTitle(newTitle || title); + }, + [dispatch, nodeId, field.name, title] + ); + + const handleChange = useCallback((newTitle: string) => { + setLocalTitle(newTitle); + }, []); + + useEffect(() => { + // Another component may change the title; sync local title with global state + setLocalTitle(label || title); + }, [label, title]); + + return ( + + + + + + + + ); +}; + +export default memo(FieldTitle); + +type EditableControlsProps = { + draggableData?: NodeFieldDraggableData; +}; + +function EditableControls(props: EditableControlsProps) { + const { isEditing, getEditButtonProps } = useEditableControls(); + const handleDoubleClick = useCallback( + (e: MouseEvent) => { + const { onClick } = getEditButtonProps(); + if (!onClick) { + return; + } + onClick(e); + }, + [getEditButtonProps] + ); + + if (isEditing) { + return null; + } + + if (props.draggableData) { + return ( + + ); + } + + return ( + + ); +} diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/FieldTooltipContent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/FieldTooltipContent.tsx new file mode 100644 index 0000000000..bf5cd3cd9b --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/fields/FieldTooltipContent.tsx @@ -0,0 +1,41 @@ +import { Flex, Text } from '@chakra-ui/react'; +import { FIELDS } from 'features/nodes/types/constants'; +import { + InputFieldTemplate, + InputFieldValue, + InvocationNodeData, + InvocationTemplate, + OutputFieldTemplate, + OutputFieldValue, + 
isInputFieldTemplate, + isInputFieldValue, +} from 'features/nodes/types/types'; +import { startCase } from 'lodash-es'; + +interface Props { + nodeData: InvocationNodeData; + nodeTemplate: InvocationTemplate; + field: InputFieldValue | OutputFieldValue; + fieldTemplate: InputFieldTemplate | OutputFieldTemplate; +} + +const FieldTooltipContent = ({ field, fieldTemplate }: Props) => { + const isInputTemplate = isInputFieldTemplate(fieldTemplate); + + return ( + + + {isInputFieldValue(field) && field.label + ? `${field.label} (${fieldTemplate.title})` + : fieldTemplate.title} + + + {fieldTemplate.description} + + Type: {FIELDS[fieldTemplate.type].title} + {isInputTemplate && Input: {startCase(fieldTemplate.input)}} + + ); +}; + +export default FieldTooltipContent; diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/InputField.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/InputField.tsx new file mode 100644 index 0000000000..67f4369384 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/fields/InputField.tsx @@ -0,0 +1,153 @@ +import { Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react'; +import { useConnectionState } from 'features/nodes/hooks/useConnectionState'; +import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants'; +import { + InputFieldValue, + InvocationNodeData, + InvocationTemplate, +} from 'features/nodes/types/types'; +import { PropsWithChildren, useMemo } from 'react'; +import { NodeProps } from 'reactflow'; +import FieldHandle from './FieldHandle'; +import FieldTitle from './FieldTitle'; +import FieldTooltipContent from './FieldTooltipContent'; +import InputFieldRenderer from './InputFieldRenderer'; + +interface Props { + nodeProps: NodeProps; + nodeTemplate: InvocationTemplate; + field: InputFieldValue; +} + +const InputField = (props: Props) => { + const { nodeProps, nodeTemplate, field } = props; + const { id: nodeId } = nodeProps.data; + + const { + 
isConnected, + isConnectionInProgress, + isConnectionStartField, + connectionError, + shouldDim, + } = useConnectionState({ nodeId, field, kind: 'input' }); + + const fieldTemplate = useMemo( + () => nodeTemplate.inputs[field.name], + [field.name, nodeTemplate.inputs] + ); + + const isMissingInput = useMemo(() => { + if (!fieldTemplate) { + return false; + } + + if (!fieldTemplate.required) { + return false; + } + + if (!isConnected && fieldTemplate.input === 'connection') { + return true; + } + + if (!field.value && !isConnected && fieldTemplate.input === 'any') { + return true; + } + }, [fieldTemplate, isConnected, field.value]); + + if (!fieldTemplate) { + return ( + + + Unknown input: {field.name} + + + ); + } + + return ( + + + + } + openDelay={HANDLE_TOOLTIP_OPEN_DELAY} + placement="top" + shouldWrapChildren + hasArrow + > + + + + + + + + {fieldTemplate.input !== 'direct' && ( + + )} + + ); +}; + +export default InputField; + +type InputFieldWrapperProps = PropsWithChildren<{ + shouldDim: boolean; +}>; + +const InputFieldWrapper = ({ shouldDim, children }: InputFieldWrapperProps) => ( + + {children} + +); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/InputFieldRenderer.tsx new file mode 100644 index 0000000000..0eae336a1e --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/fields/InputFieldRenderer.tsx @@ -0,0 +1,292 @@ +import { Box } from '@chakra-ui/react'; +import { memo } from 'react'; +import { + InputFieldTemplate, + InputFieldValue, + InvocationNodeData, + InvocationTemplate, +} from '../../types/types'; +import BooleanInputField from './fieldTypes/BooleanInputField'; +import ClipInputField from './fieldTypes/ClipInputField'; +import CollectionInputField from './fieldTypes/CollectionInputField'; +import CollectionItemInputField from './fieldTypes/CollectionItemInputField'; +import ColorInputField from 
'./fieldTypes/ColorInputField'; +import ConditioningInputField from './fieldTypes/ConditioningInputField'; +import ControlInputField from './fieldTypes/ControlInputField'; +import ControlNetModelInputField from './fieldTypes/ControlNetModelInputField'; +import EnumInputField from './fieldTypes/EnumInputField'; +import ImageCollectionInputField from './fieldTypes/ImageCollectionInputField'; +import ImageInputField from './fieldTypes/ImageInputField'; +import LatentsInputField from './fieldTypes/LatentsInputField'; +import LoRAModelInputField from './fieldTypes/LoRAModelInputField'; +import MainModelInputField from './fieldTypes/MainModelInputField'; +import NumberInputField from './fieldTypes/NumberInputField'; +import RefinerModelInputField from './fieldTypes/RefinerModelInputField'; +import SDXLMainModelInputField from './fieldTypes/SDXLMainModelInputField'; +import StringInputField from './fieldTypes/StringInputField'; +import UnetInputField from './fieldTypes/UnetInputField'; +import VaeInputField from './fieldTypes/VaeInputField'; +import VaeModelInputField from './fieldTypes/VaeModelInputField'; + +type InputFieldProps = { + nodeData: InvocationNodeData; + nodeTemplate: InvocationTemplate; + field: InputFieldValue; + fieldTemplate: InputFieldTemplate; +}; + +// build an individual input element based on the schema +const InputFieldRenderer = (props: InputFieldProps) => { + const { nodeData, nodeTemplate, field, fieldTemplate } = props; + const { type } = field; + + if (type === 'string' && fieldTemplate.type === 'string') { + return ( + + ); + } + + if (type === 'boolean' && fieldTemplate.type === 'boolean') { + return ( + + ); + } + + if ( + (type === 'integer' && fieldTemplate.type === 'integer') || + (type === 'float' && fieldTemplate.type === 'float') + ) { + return ( + + ); + } + + if (type === 'enum' && fieldTemplate.type === 'enum') { + return ( + + ); + } + + if (type === 'ImageField' && fieldTemplate.type === 'ImageField') { + return ( + + ); + } + + 
if (type === 'LatentsField' && fieldTemplate.type === 'LatentsField') { + return ( + + ); + } + + if ( + type === 'ConditioningField' && + fieldTemplate.type === 'ConditioningField' + ) { + return ( + + ); + } + + if (type === 'UNetField' && fieldTemplate.type === 'UNetField') { + return ( + + ); + } + + if (type === 'ClipField' && fieldTemplate.type === 'ClipField') { + return ( + + ); + } + + if (type === 'VaeField' && fieldTemplate.type === 'VaeField') { + return ( + + ); + } + + if (type === 'ControlField' && fieldTemplate.type === 'ControlField') { + return ( + + ); + } + + if (type === 'MainModelField' && fieldTemplate.type === 'MainModelField') { + return ( + + ); + } + + if ( + type === 'SDXLRefinerModelField' && + fieldTemplate.type === 'SDXLRefinerModelField' + ) { + return ( + + ); + } + + if (type === 'VaeModelField' && fieldTemplate.type === 'VaeModelField') { + return ( + + ); + } + + if (type === 'LoRAModelField' && fieldTemplate.type === 'LoRAModelField') { + return ( + + ); + } + + if ( + type === 'ControlNetModelField' && + fieldTemplate.type === 'ControlNetModelField' + ) { + return ( + + ); + } + + if (type === 'Collection' && fieldTemplate.type === 'Collection') { + return ( + + ); + } + + if (type === 'CollectionItem' && fieldTemplate.type === 'CollectionItem') { + return ( + + ); + } + + if (type === 'ColorField' && fieldTemplate.type === 'ColorField') { + return ( + + ); + } + + if (type === 'ImageCollection' && fieldTemplate.type === 'ImageCollection') { + return ( + + ); + } + + if ( + type === 'SDXLMainModelField' && + fieldTemplate.type === 'SDXLMainModelField' + ) { + return ( + + ); + } + + return Unknown field type: {type}; +}; + +export default memo(InputFieldRenderer); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ItemInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ItemInputFieldComponent.tsx deleted file mode 100644 index 6fa89345bf..0000000000 --- 
a/invokeai/frontend/web/src/features/nodes/components/fields/ItemInputFieldComponent.tsx +++ /dev/null @@ -1,15 +0,0 @@ -import { - ItemInputFieldTemplate, - ItemInputFieldValue, -} from 'features/nodes/types/types'; -import { memo } from 'react'; -import { FaAddressCard } from 'react-icons/fa'; -import { FieldComponentProps } from './types'; - -const ItemInputFieldComponent = ( - _props: FieldComponentProps -) => { - return ; -}; - -export default memo(ItemInputFieldComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/LinearViewField.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/LinearViewField.tsx new file mode 100644 index 0000000000..98a8000b1a --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/fields/LinearViewField.tsx @@ -0,0 +1,88 @@ +import { Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react'; +import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants'; +import { + InputFieldTemplate, + InputFieldValue, + InvocationNodeData, + InvocationTemplate, +} from 'features/nodes/types/types'; +import { memo } from 'react'; +import FieldTitle from './FieldTitle'; +import FieldTooltipContent from './FieldTooltipContent'; +import InputFieldRenderer from './InputFieldRenderer'; + +type Props = { + nodeData: InvocationNodeData; + nodeTemplate: InvocationTemplate; + field: InputFieldValue; + fieldTemplate: InputFieldTemplate; +}; + +const LinearViewField = ({ + nodeData, + nodeTemplate, + field, + fieldTemplate, +}: Props) => { + // const dispatch = useAppDispatch(); + // const handleRemoveField = useCallback(() => { + // dispatch( + // workflowExposedFieldRemoved({ + // nodeId: nodeData.id, + // fieldName: field.name, + // }) + // ); + // }, [dispatch, field.name, nodeData.id]); + + return ( + + + + } + openDelay={HANDLE_TOOLTIP_OPEN_DELAY} + placement="top" + shouldWrapChildren + hasArrow + > + + + + + + + + ); +}; + +export default memo(LinearViewField); diff --git 
a/invokeai/frontend/web/src/features/nodes/components/fields/OutputField.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/OutputField.tsx new file mode 100644 index 0000000000..5a29d1ab7e --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/fields/OutputField.tsx @@ -0,0 +1,114 @@ +import { + Flex, + FormControl, + FormLabel, + Spacer, + Tooltip, +} from '@chakra-ui/react'; +import { useConnectionState } from 'features/nodes/hooks/useConnectionState'; +import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants'; +import { + InvocationNodeData, + InvocationTemplate, + OutputFieldValue, +} from 'features/nodes/types/types'; +import { PropsWithChildren, useMemo } from 'react'; +import { NodeProps } from 'reactflow'; +import FieldHandle from './FieldHandle'; +import FieldTooltipContent from './FieldTooltipContent'; + +interface Props { + nodeProps: NodeProps; + nodeTemplate: InvocationTemplate; + field: OutputFieldValue; +} + +const OutputField = (props: Props) => { + const { nodeTemplate, nodeProps, field } = props; + + const { + isConnected, + isConnectionInProgress, + isConnectionStartField, + connectionError, + shouldDim, + } = useConnectionState({ nodeId: nodeProps.data.id, field, kind: 'output' }); + + const fieldTemplate = useMemo( + () => nodeTemplate.outputs[field.name], + [field.name, nodeTemplate] + ); + + if (!fieldTemplate) { + return ( + + + Unknown output: {field.name} + + + ); + } + + return ( + + + + } + openDelay={HANDLE_TOOLTIP_OPEN_DELAY} + placement="top" + shouldWrapChildren + hasArrow + > + + + {fieldTemplate?.title} + + + + + + ); +}; + +export default OutputField; + +type OutputFieldWrapperProps = PropsWithChildren<{ + shouldDim: boolean; +}>; + +const OutputFieldWrapper = ({ + shouldDim, + children, +}: OutputFieldWrapperProps) => ( + + {children} + +); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/StringInputFieldComponent.tsx 
b/invokeai/frontend/web/src/features/nodes/components/fields/StringInputFieldComponent.tsx deleted file mode 100644 index 18cf7e997f..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/fields/StringInputFieldComponent.tsx +++ /dev/null @@ -1,36 +0,0 @@ -import { Input, Textarea } from '@chakra-ui/react'; -import { useAppDispatch } from 'app/store/storeHooks'; -import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; -import { - StringInputFieldTemplate, - StringInputFieldValue, -} from 'features/nodes/types/types'; -import { ChangeEvent, memo } from 'react'; -import { FieldComponentProps } from './types'; - -const StringInputFieldComponent = ( - props: FieldComponentProps -) => { - const { nodeId, field } = props; - const dispatch = useAppDispatch(); - - const handleValueChanged = ( - e: ChangeEvent - ) => { - dispatch( - fieldValueChanged({ - nodeId, - fieldName: field.name, - value: e.target.value, - }) - ); - }; - - return ['prompt', 'style'].includes(field.name.toLowerCase()) ? ( -