Merge branch 'main' into refactor/rename-get-logger

This commit is contained in:
Lincoln Stein 2023-08-16 09:19:52 -04:00 committed by GitHub
commit fc9b4539a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
246 changed files with 11233 additions and 7616 deletions

View File

@ -5,7 +5,7 @@ from PIL import Image
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from pydantic import BaseModel from pydantic import BaseModel, Field
from invokeai.app.invocations.metadata import ImageMetadata from invokeai.app.invocations.metadata import ImageMetadata
from invokeai.app.models.image import ImageCategory, ResourceOrigin from invokeai.app.models.image import ImageCategory, ResourceOrigin
@ -19,6 +19,7 @@ from ..dependencies import ApiDependencies
images_router = APIRouter(prefix="/v1/images", tags=["images"]) images_router = APIRouter(prefix="/v1/images", tags=["images"])
# images are immutable; set a high max-age # images are immutable; set a high max-age
IMAGE_MAX_AGE = 31536000 IMAGE_MAX_AGE = 31536000
@ -286,3 +287,41 @@ async def delete_images_from_list(
return DeleteImagesFromListResult(deleted_images=deleted_images) return DeleteImagesFromListResult(deleted_images=deleted_images)
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail="Failed to delete images") raise HTTPException(status_code=500, detail="Failed to delete images")
class ImagesUpdatedFromListResult(BaseModel):
updated_image_names: list[str] = Field(description="The image names that were updated")
@images_router.post("/star", operation_id="star_images_in_list", response_model=ImagesUpdatedFromListResult)
async def star_images_in_list(
image_names: list[str] = Body(description="The list of names of images to star", embed=True),
) -> ImagesUpdatedFromListResult:
try:
updated_image_names: list[str] = []
for image_name in image_names:
try:
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=True))
updated_image_names.append(image_name)
except:
pass
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to star images")
@images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=ImagesUpdatedFromListResult)
async def unstar_images_in_list(
image_names: list[str] = Body(description="The list of names of images to unstar", embed=True),
) -> ImagesUpdatedFromListResult:
try:
updated_image_names: list[str] = []
for image_name in image_names:
try:
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=False))
updated_image_names.append(image_name)
except:
pass
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to unstar images")

View File

@ -38,7 +38,7 @@ import mimetypes
from .api.dependencies import ApiDependencies from .api.dependencies import ApiDependencies
from .api.routers import sessions, models, images, boards, board_images, app_info from .api.routers import sessions, models, images, boards, board_images, app_info
from .api.sockets import SocketIO from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
import torch import torch
@ -134,6 +134,11 @@ def custom_openapi():
# This could break in some cases, figure out a better way to do it # This could break in some cases, figure out a better way to do it
output_type_titles[schema_key] = output_schema["title"] output_type_titles[schema_key] = output_schema["title"]
# Add Node Editor UI helper schemas
ui_config_schemas = schema([UIConfigBase, _InputField, _OutputField], ref_prefix="#/components/schemas/")
for schema_key, output_schema in ui_config_schemas["definitions"].items():
openapi_schema["components"]["schemas"][schema_key] = output_schema
# Add a reference to the output type to additionalProperties of the invoker schema # Add a reference to the output type to additionalProperties of the invoker schema
for invoker in all_invocations: for invoker in all_invocations:
invoker_name = invoker.__name__ invoker_name = invoker.__name__

View File

@ -3,15 +3,366 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum
from inspect import signature from inspect import signature
from typing import TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args, get_type_hints from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Callable,
ClassVar,
Mapping,
Optional,
Type,
TypeVar,
Union,
get_args,
get_type_hints,
)
from pydantic import BaseConfig, BaseModel, Field from pydantic import BaseModel, Field
from pydantic.fields import Undefined
from pydantic.typing import NoArgAnyCallable
if TYPE_CHECKING: if TYPE_CHECKING:
from ..services.invocation_services import InvocationServices from ..services.invocation_services import InvocationServices
class FieldDescriptions:
denoising_start = "When to start denoising, expressed a percentage of total steps"
denoising_end = "When to stop denoising, expressed a percentage of total steps"
cfg_scale = "Classifier-Free Guidance scale"
scheduler = "Scheduler to use during inference"
positive_cond = "Positive conditioning tensor"
negative_cond = "Negative conditioning tensor"
noise = "Noise tensor"
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
unet = "UNet (scheduler, LoRAs)"
vae = "VAE"
cond = "Conditioning tensor"
controlnet_model = "ControlNet model to load"
vae_model = "VAE model to load"
lora_model = "LoRA model to load"
main_model = "Main model (UNet, VAE, CLIP) to load"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
lora_weight = "The weight at which the LoRA is applied to each model"
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
raw_prompt = "Raw prompt text (no parsing)"
sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor"
skipped_layers = "Number of layers to skip in text encoder"
seed = "Seed for random number generation"
steps = "Number of steps to run"
width = "Width of output (px)"
height = "Height of output (px)"
control = "ControlNet(s) to apply"
denoised_latents = "Denoised latents tensor"
latents = "Latents tensor"
strength = "Strength of denoising (proportional to steps)"
core_metadata = "Optional core metadata to be written to image"
interp_mode = "Interpolation mode"
torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
fp32 = "Whether or not to use full float32 precision"
precision = "Precision to use"
tiled = "Processing using overlapping tiles (reduce memory consumption)"
detect_res = "Pixel resolution for detection"
image_res = "Pixel resolution for output image"
safe_mode = "Whether or not to use safe mode"
scribble_mode = "Whether or not to use scribble mode"
scale_factor = "The factor by which to scale"
num_1 = "The first number"
num_2 = "The second number"
mask = "The mask to use for the operation"
class Input(str, Enum):
"""
The type of input a field accepts.
- `Input.Direct`: The field must have its value provided directly, when the invocation and field \
are instantiated.
- `Input.Connection`: The field must have its value provided by a connection.
- `Input.Any`: The field may have its value provided either directly or by a connection.
"""
Connection = "connection"
Direct = "direct"
Any = "any"
class UIType(str, Enum):
"""
Type hints for the UI.
If a field should be provided a data type that does not exactly match the python type of the field, \
use this to provide the type that should be used instead. See the node development docs for detail \
on adding a new field type, which involves client-side changes.
"""
# region Primitives
Integer = "integer"
Float = "float"
Boolean = "boolean"
String = "string"
Array = "array"
Image = "ImageField"
Latents = "LatentsField"
Conditioning = "ConditioningField"
Control = "ControlField"
Color = "ColorField"
ImageCollection = "ImageCollection"
ConditioningCollection = "ConditioningCollection"
ColorCollection = "ColorCollection"
LatentsCollection = "LatentsCollection"
IntegerCollection = "IntegerCollection"
FloatCollection = "FloatCollection"
StringCollection = "StringCollection"
BooleanCollection = "BooleanCollection"
# endregion
# region Models
MainModel = "MainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
ONNXModel = "ONNXModelField"
VaeModel = "VaeModelField"
LoRAModel = "LoRAModelField"
ControlNetModel = "ControlNetModelField"
UNet = "UNetField"
Vae = "VaeField"
CLIP = "ClipField"
# endregion
# region Iterate/Collect
Collection = "Collection"
CollectionItem = "CollectionItem"
# endregion
# region Misc
FilePath = "FilePath"
Enum = "enum"
# endregion
class UIComponent(str, Enum):
"""
The type of UI component to use for a field, used to override the default components, which are \
inferred from the field type.
"""
None_ = "none"
Textarea = "textarea"
Slider = "slider"
class _InputField(BaseModel):
"""
*DO NOT USE*
This helper class is used to tell the client about our custom field attributes via OpenAPI
schema generation, and Typescript type generation from that schema. It serves no functional
purpose in the backend.
"""
input: Input
ui_hidden: bool
ui_type: Optional[UIType]
ui_component: Optional[UIComponent]
class _OutputField(BaseModel):
"""
*DO NOT USE*
This helper class is used to tell the client about our custom field attributes via OpenAPI
schema generation, and Typescript type generation from that schema. It serves no functional
purpose in the backend.
"""
ui_hidden: bool
ui_type: Optional[UIType]
def InputField(
*args: Any,
default: Any = Undefined,
default_factory: Optional[NoArgAnyCallable] = None,
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
exclude: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
include: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
const: Optional[bool] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
multiple_of: Optional[float] = None,
allow_inf_nan: Optional[bool] = None,
max_digits: Optional[int] = None,
decimal_places: Optional[int] = None,
min_items: Optional[int] = None,
max_items: Optional[int] = None,
unique_items: Optional[bool] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
allow_mutation: bool = True,
regex: Optional[str] = None,
discriminator: Optional[str] = None,
repr: bool = True,
input: Input = Input.Any,
ui_type: Optional[UIType] = None,
ui_component: Optional[UIComponent] = None,
ui_hidden: bool = False,
**kwargs: Any,
) -> Any:
"""
Creates an input field for an invocation.
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
that adds a few extra parameters to support graph execution and the node editor UI.
:param Input input: [Input.Any] The kind of input this field requires. \
`Input.Direct` means a value must be provided on instantiation. \
`Input.Connection` means the value must be provided by a connection. \
`Input.Any` means either will do.
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
In some situations, the field's type is not enough to infer the correct UI type. \
For example, model selection fields should render a dropdown UI component to select a model. \
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
:param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \
The UI will always render a suitable component, but sometimes you want something different than the default. \
For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \
For this case, you could provide `UIComponent.Textarea`.
: param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
"""
return Field(
*args,
default=default,
default_factory=default_factory,
alias=alias,
title=title,
description=description,
exclude=exclude,
include=include,
const=const,
gt=gt,
ge=ge,
lt=lt,
le=le,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
min_items=min_items,
max_items=max_items,
unique_items=unique_items,
min_length=min_length,
max_length=max_length,
allow_mutation=allow_mutation,
regex=regex,
discriminator=discriminator,
repr=repr,
input=input,
ui_type=ui_type,
ui_component=ui_component,
ui_hidden=ui_hidden,
**kwargs,
)
def OutputField(
*args: Any,
default: Any = Undefined,
default_factory: Optional[NoArgAnyCallable] = None,
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
exclude: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
include: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
const: Optional[bool] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
multiple_of: Optional[float] = None,
allow_inf_nan: Optional[bool] = None,
max_digits: Optional[int] = None,
decimal_places: Optional[int] = None,
min_items: Optional[int] = None,
max_items: Optional[int] = None,
unique_items: Optional[bool] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
allow_mutation: bool = True,
regex: Optional[str] = None,
discriminator: Optional[str] = None,
repr: bool = True,
ui_type: Optional[UIType] = None,
ui_hidden: bool = False,
**kwargs: Any,
) -> Any:
"""
Creates an output field for an invocation output.
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
that adds a few extra parameters to support graph execution and the node editor UI.
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
In some situations, the field's type is not enough to infer the correct UI type. \
For example, model selection fields should render a dropdown UI component to select a model. \
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
: param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
"""
return Field(
*args,
default=default,
default_factory=default_factory,
alias=alias,
title=title,
description=description,
exclude=exclude,
include=include,
const=const,
gt=gt,
ge=ge,
lt=lt,
le=le,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
min_items=min_items,
max_items=max_items,
unique_items=unique_items,
min_length=min_length,
max_length=max_length,
allow_mutation=allow_mutation,
regex=regex,
discriminator=discriminator,
repr=repr,
ui_type=ui_type,
ui_hidden=ui_hidden,
**kwargs,
)
class UIConfigBase(BaseModel):
"""
Provides additional node configuration to the UI.
This is used internally by the @tags and @title decorator logic. You probably want to use those
decorators, though you may add this class to a node definition to specify the title and tags.
"""
tags: Optional[list[str]] = Field(default_factory=None, description="The tags to display in the UI")
title: Optional[str] = Field(default=None, description="The display name of the node")
class InvocationContext: class InvocationContext:
services: InvocationServices services: InvocationServices
graph_execution_state_id: str graph_execution_state_id: str
@ -39,6 +390,20 @@ class BaseInvocationOutput(BaseModel):
return tuple(subclasses) return tuple(subclasses)
class RequiredConnectionException(Exception):
"""Raised when an field which requires a connection did not receive a value."""
def __init__(self, node_id: str, field_name: str):
super().__init__(f"Node {node_id} missing connections for field {field_name}")
class MissingInputException(Exception):
"""Raised when an field which requires some input, but did not receive a value."""
def __init__(self, node_id: str, field_name: str):
super().__init__(f"Node {node_id} missing value or connection for field {field_name}")
class BaseInvocation(ABC, BaseModel): class BaseInvocation(ABC, BaseModel):
"""A node to process inputs and produce outputs. """A node to process inputs and produce outputs.
May use dependency injection in __init__ to receive providers. May use dependency injection in __init__ to receive providers.
@ -76,70 +441,81 @@ class BaseInvocation(ABC, BaseModel):
def get_output_type(cls): def get_output_type(cls):
return signature(cls.invoke).return_annotation return signature(cls.invoke).return_annotation
class Config:
@staticmethod
def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
uiconfig = getattr(model_class, "UIConfig", None)
if uiconfig and hasattr(uiconfig, "title"):
schema["title"] = uiconfig.title
if uiconfig and hasattr(uiconfig, "tags"):
schema["tags"] = uiconfig.tags
@abstractmethod @abstractmethod
def invoke(self, context: InvocationContext) -> BaseInvocationOutput: def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
"""Invoke with provided context and return outputs.""" """Invoke with provided context and return outputs."""
pass pass
# fmt: off def __init__(self, **data):
id: str = Field(description="The id of this node. Must be unique among all nodes.") # nodes may have required fields, that can accept input from connections
is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.") # on instantiation of the model, we need to exclude these from validation
# fmt: on restore = dict()
try:
field_names = list(self.__fields__.keys())
for field_name in field_names:
# if the field is required and may get its value from a connection, exclude it from validation
field = self.__fields__[field_name]
_input = field.field_info.extra.get("input", None)
if _input in [Input.Connection, Input.Any] and field.required:
if field_name not in data:
restore[field_name] = self.__fields__.pop(field_name)
# instantiate the node, which will validate the data
super().__init__(**data)
finally:
# restore the removed fields
for field_name, field in restore.items():
self.__fields__[field_name] = field
def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
for field_name, field in self.__fields__.items():
_input = field.field_info.extra.get("input", None)
if field.required and not hasattr(self, field_name):
if _input == Input.Connection:
raise RequiredConnectionException(self.__fields__["type"].default, field_name)
elif _input == Input.Any:
raise MissingInputException(self.__fields__["type"].default, field_name)
return self.invoke(context)
id: str = InputField(description="The id of this node. Must be unique among all nodes.")
is_intermediate: bool = InputField(
default=False, description="Whether or not this node is an intermediate node.", input=Input.Direct
)
UIConfig: ClassVar[Type[UIConfigBase]]
# TODO: figure out a better way to provide these hints T = TypeVar("T", bound=BaseInvocation)
# TODO: when we can upgrade to python 3.11, we can use the`NotRequired` type instead of `total=False`
class UIConfig(TypedDict, total=False):
type_hints: Dict[
str,
Literal[
"integer",
"float",
"boolean",
"string",
"enum",
"image",
"latents",
"model",
"control",
"image_collection",
"vae_model",
"lora_model",
],
]
tags: List[str]
title: str
class CustomisedSchemaExtra(TypedDict): def title(title: str) -> Callable[[Type[T]], Type[T]]:
ui: UIConfig """Adds a title to the invocation. Use this to override the default title generation, which is based on the class name."""
def wrapper(cls: Type[T]) -> Type[T]:
uiconf_name = cls.__qualname__ + ".UIConfig"
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
cls.UIConfig.title = title
return cls
return wrapper
class InvocationConfig(BaseConfig): def tags(*tags: str) -> Callable[[Type[T]], Type[T]]:
"""Customizes pydantic's BaseModel.Config class for use by Invocations. """Adds tags to the invocation. Use this to improve the streamline finding the invocation in the UI."""
Provide `schema_extra` a `ui` dict to add hints for generated UIs. def wrapper(cls: Type[T]) -> Type[T]:
uiconf_name = cls.__qualname__ + ".UIConfig"
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
cls.UIConfig.tags = list(tags)
return cls
`tags` return wrapper
- A list of strings, used to categorise invocations.
`type_hints`
- A dict of field types which override the types in the invocation definition.
- Each key should be the name of one of the invocation's fields.
- Each value should be one of the valid types:
- `integer`, `float`, `boolean`, `string`, `enum`, `image`, `latents`, `model`
```python
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["stable-diffusion", "image"],
"type_hints": {
"initial_image": "image",
},
},
}
```
"""
schema_extra: CustomisedSchemaExtra

View File

@ -3,58 +3,25 @@
from typing import Literal from typing import Literal
import numpy as np import numpy as np
from pydantic import Field, validator from pydantic import validator
from invokeai.app.models.image import ImageField from invokeai.app.invocations.primitives import ImageCollectionOutput, ImageField, IntegerCollectionOutput
from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.misc import SEED_MAX, get_random_seed
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext, UIConfig from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIType, tags, title
class IntCollectionOutput(BaseInvocationOutput):
"""A collection of integers"""
type: Literal["int_collection"] = "int_collection"
# Outputs
collection: list[int] = Field(default=[], description="The int collection")
class FloatCollectionOutput(BaseInvocationOutput):
"""A collection of floats"""
type: Literal["float_collection"] = "float_collection"
# Outputs
collection: list[float] = Field(default=[], description="The float collection")
class ImageCollectionOutput(BaseInvocationOutput):
"""A collection of images"""
type: Literal["image_collection"] = "image_collection"
# Outputs
collection: list[ImageField] = Field(default=[], description="The output images")
class Config:
schema_extra = {"required": ["type", "collection"]}
@title("Integer Range")
@tags("collection", "integer", "range")
class RangeInvocation(BaseInvocation): class RangeInvocation(BaseInvocation):
"""Creates a range of numbers from start to stop with step""" """Creates a range of numbers from start to stop with step"""
type: Literal["range"] = "range" type: Literal["range"] = "range"
# Inputs # Inputs
start: int = Field(default=0, description="The start of the range") start: int = InputField(default=0, description="The start of the range")
stop: int = Field(default=10, description="The stop of the range") stop: int = InputField(default=10, description="The stop of the range")
step: int = Field(default=1, description="The step of the range") step: int = InputField(default=1, description="The step of the range")
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Range", "tags": ["range", "integer", "collection"]},
}
@validator("stop") @validator("stop")
def stop_gt_start(cls, v, values): def stop_gt_start(cls, v, values):
@ -62,76 +29,44 @@ class RangeInvocation(BaseInvocation):
raise ValueError("stop must be greater than start") raise ValueError("stop must be greater than start")
return v return v
def invoke(self, context: InvocationContext) -> IntCollectionOutput: def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step))) return IntegerCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
@title("Integer Range of Size")
@tags("range", "integer", "size", "collection")
class RangeOfSizeInvocation(BaseInvocation): class RangeOfSizeInvocation(BaseInvocation):
"""Creates a range from start to start + size with step""" """Creates a range from start to start + size with step"""
type: Literal["range_of_size"] = "range_of_size" type: Literal["range_of_size"] = "range_of_size"
# Inputs # Inputs
start: int = Field(default=0, description="The start of the range") start: int = InputField(default=0, description="The start of the range")
size: int = Field(default=1, description="The number of values") size: int = InputField(default=1, description="The number of values")
step: int = Field(default=1, description="The step of the range") step: int = InputField(default=1, description="The step of the range")
class Config(InvocationConfig): def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
schema_extra = { return IntegerCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))
"ui": {"title": "Sized Range", "tags": ["range", "integer", "size", "collection"]},
}
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))
@title("Random Range")
@tags("range", "integer", "random", "collection")
class RandomRangeInvocation(BaseInvocation): class RandomRangeInvocation(BaseInvocation):
"""Creates a collection of random numbers""" """Creates a collection of random numbers"""
type: Literal["random_range"] = "random_range" type: Literal["random_range"] = "random_range"
# Inputs # Inputs
low: int = Field(default=0, description="The inclusive low value") low: int = InputField(default=0, description="The inclusive low value")
high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value") high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
size: int = Field(default=1, description="The number of values to generate") size: int = InputField(default=1, description="The number of values to generate")
seed: int = Field( seed: int = InputField(
ge=0, ge=0,
le=SEED_MAX, le=SEED_MAX,
description="The seed for the RNG (omit for random)", description="The seed for the RNG (omit for random)",
default_factory=get_random_seed, default_factory=get_random_seed,
) )
class Config(InvocationConfig): def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
schema_extra = {
"ui": {"title": "Random Range", "tags": ["range", "integer", "random", "collection"]},
}
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
rng = np.random.default_rng(self.seed) rng = np.random.default_rng(self.seed)
return IntCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size))) return IntegerCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size)))
class ImageCollectionInvocation(BaseInvocation):
"""Load a collection of images and provide it as output."""
# fmt: off
type: Literal["image_collection"] = "image_collection"
# Inputs
images: list[ImageField] = Field(
default=[], description="The image collection to load"
)
# fmt: on
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
return ImageCollectionOutput(collection=self.images)
class Config(InvocationConfig):
schema_extra = {
"ui": {
"type_hints": {
"title": "Image Collection",
"images": "image_collection",
}
},
}

View File

@ -1,32 +1,35 @@
from typing import Literal, Optional, Union, List, Annotated
from pydantic import BaseModel, Field
import re import re
from dataclasses import dataclass
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig from typing import List, Literal, Union
from .model import ClipField
from ...backend.util.devices import torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.model_management import BaseModelType, ModelType, SubModelType, ModelPatcher
import torch import torch
from compel import Compel, ReturnedEmbeddingsType from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from ...backend.util.devices import torch_dtype from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
from ...backend.model_management import ModelType
from ...backend.model_management.models import ModelNotFoundException from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import (
BasicConditioningInfo,
SDXLConditioningInfo,
)
from ...backend.model_management import ModelPatcher, ModelType
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import InvokeAIDiffuserComponent, BasicConditioningInfo, SDXLConditioningInfo from ...backend.model_management.models import ModelNotFoundException
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.util.devices import torch_dtype
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
OutputField,
UIComponent,
tags,
title,
)
from .model import ClipField from .model import ClipField
from dataclasses import dataclass
class ConditioningField(BaseModel):
conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")
class Config:
schema_extra = {"required": ["conditioning_name"]}
@dataclass @dataclass
@ -41,32 +44,26 @@ class ConditioningFieldData:
# PerpNeg = "perp_neg" # PerpNeg = "perp_neg"
class CompelOutput(BaseInvocationOutput): @title("Compel Prompt")
"""Compel parser output""" @tags("prompt", "compel")
# fmt: off
type: Literal["compel_output"] = "compel_output"
conditioning: ConditioningField = Field(default=None, description="Conditioning")
# fmt: on
class CompelInvocation(BaseInvocation): class CompelInvocation(BaseInvocation):
"""Parse prompt using compel package to conditioning.""" """Parse prompt using compel package to conditioning."""
type: Literal["compel"] = "compel" type: Literal["compel"] = "compel"
prompt: str = Field(default="", description="Prompt") prompt: str = InputField(
clip: ClipField = Field(None, description="Clip to use") default="",
description=FieldDescriptions.compel_prompt,
# Schema customisation ui_component=UIComponent.Textarea,
class Config(InvocationConfig): )
schema_extra = { clip: ClipField = InputField(
"ui": {"title": "Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}}, title="CLIP",
} description=FieldDescriptions.clip,
input=Input.Connection,
)
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput: def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.services.model_manager.get_model( tokenizer_info = context.services.model_manager.get_model(
**self.clip.tokenizer.dict(), **self.clip.tokenizer.dict(),
context=context, context=context,
@ -149,7 +146,7 @@ class CompelInvocation(BaseInvocation):
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
context.services.latents.save(conditioning_name, conditioning_data) context.services.latents.save(conditioning_name, conditioning_data)
return CompelOutput( return ConditioningOutput(
conditioning=ConditioningField( conditioning=ConditioningField(
conditioning_name=conditioning_name, conditioning_name=conditioning_name,
), ),
@ -270,30 +267,26 @@ class SDXLPromptInvocationBase:
return c, c_pooled, ec return c, c_pooled, ec
@title("SDXL Compel Prompt")
@tags("sdxl", "compel", "prompt")
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning.""" """Parse prompt using compel package to conditioning."""
type: Literal["sdxl_compel_prompt"] = "sdxl_compel_prompt" type: Literal["sdxl_compel_prompt"] = "sdxl_compel_prompt"
prompt: str = Field(default="", description="Prompt") prompt: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
style: str = Field(default="", description="Style prompt") style: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
original_width: int = Field(1024, description="") original_width: int = InputField(default=1024, description="")
original_height: int = Field(1024, description="") original_height: int = InputField(default=1024, description="")
crop_top: int = Field(0, description="") crop_top: int = InputField(default=0, description="")
crop_left: int = Field(0, description="") crop_left: int = InputField(default=0, description="")
target_width: int = Field(1024, description="") target_width: int = InputField(default=1024, description="")
target_height: int = Field(1024, description="") target_height: int = InputField(default=1024, description="")
clip: ClipField = Field(None, description="Clip to use") clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
clip2: ClipField = Field(None, description="Clip2 to use") clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "SDXL Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
}
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput: def invoke(self, context: InvocationContext) -> ConditioningOutput:
c1, c1_pooled, ec1 = self.run_clip_compel( c1, c1_pooled, ec1 = self.run_clip_compel(
context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
) )
@ -326,38 +319,32 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
context.services.latents.save(conditioning_name, conditioning_data) context.services.latents.save(conditioning_name, conditioning_data)
return CompelOutput( return ConditioningOutput(
conditioning=ConditioningField( conditioning=ConditioningField(
conditioning_name=conditioning_name, conditioning_name=conditioning_name,
), ),
) )
@title("SDXL Refiner Compel Prompt")
@tags("sdxl", "compel", "prompt")
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning.""" """Parse prompt using compel package to conditioning."""
type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt" type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt"
style: str = Field(default="", description="Style prompt") # TODO: ? style: str = InputField(
original_width: int = Field(1024, description="") default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea
original_height: int = Field(1024, description="") ) # TODO: ?
crop_top: int = Field(0, description="") original_width: int = InputField(default=1024, description="")
crop_left: int = Field(0, description="") original_height: int = InputField(default=1024, description="")
aesthetic_score: float = Field(6.0, description="") crop_top: int = InputField(default=0, description="")
clip2: ClipField = Field(None, description="Clip to use") crop_left: int = InputField(default=0, description="")
aesthetic_score: float = InputField(default=6.0, description=FieldDescriptions.sdxl_aesthetic)
# Schema customisation clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Refiner Prompt (Compel)",
"tags": ["prompt", "compel"],
"type_hints": {"model": "model"},
},
}
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput: def invoke(self, context: InvocationContext) -> ConditioningOutput:
# TODO: if there will appear lora for refiner - write proper prefix # TODO: if there will appear lora for refiner - write proper prefix
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>", zero_on_empty=False) c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>", zero_on_empty=False)
@ -380,7 +367,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
context.services.latents.save(conditioning_name, conditioning_data) context.services.latents.save(conditioning_name, conditioning_data)
return CompelOutput( return ConditioningOutput(
conditioning=ConditioningField( conditioning=ConditioningField(
conditioning_name=conditioning_name, conditioning_name=conditioning_name,
), ),
@ -391,21 +378,18 @@ class ClipSkipInvocationOutput(BaseInvocationOutput):
"""Clip skip node output""" """Clip skip node output"""
type: Literal["clip_skip_output"] = "clip_skip_output" type: Literal["clip_skip_output"] = "clip_skip_output"
clip: ClipField = Field(None, description="Clip with skipped layers") clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@title("CLIP Skip")
@tags("clipskip", "clip", "skip")
class ClipSkipInvocation(BaseInvocation): class ClipSkipInvocation(BaseInvocation):
"""Skip layers in clip text_encoder model.""" """Skip layers in clip text_encoder model."""
type: Literal["clip_skip"] = "clip_skip" type: Literal["clip_skip"] = "clip_skip"
clip: ClipField = Field(None, description="Clip to use") clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
skipped_layers: int = Field(0, description="Number of layers to skip in text_encoder") skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "CLIP Skip", "tags": ["clip", "skip"]},
}
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput: def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
self.clip.skipped_layers += self.skipped_layers self.clip.skipped_layers += self.skipped_layers

View File

@ -26,79 +26,31 @@ from controlnet_aux.util import HWC3, ade_palette
from PIL import Image from PIL import Image
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from ...backend.model_management import BaseModelType, ModelType from ...backend.model_management import BaseModelType, ModelType
from ..models.image import ImageCategory, ImageField, ResourceOrigin from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext from .baseinvocation import (
from ..models.image import ImageOutput, PILInvocationConfig BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
InputField,
Input,
InvocationContext,
OutputField,
UIType,
tags,
title,
)
CONTROLNET_DEFAULT_MODELS = [
###########################################
# lllyasviel sd v1.5, ControlNet v1.0 models
##############################################
"lllyasviel/sd-controlnet-canny",
"lllyasviel/sd-controlnet-depth",
"lllyasviel/sd-controlnet-hed",
"lllyasviel/sd-controlnet-seg",
"lllyasviel/sd-controlnet-openpose",
"lllyasviel/sd-controlnet-scribble",
"lllyasviel/sd-controlnet-normal",
"lllyasviel/sd-controlnet-mlsd",
#############################################
# lllyasviel sd v1.5, ControlNet v1.1 models
#############################################
"lllyasviel/control_v11p_sd15_canny",
"lllyasviel/control_v11p_sd15_openpose",
"lllyasviel/control_v11p_sd15_seg",
# "lllyasviel/control_v11p_sd15_depth", # broken
"lllyasviel/control_v11f1p_sd15_depth",
"lllyasviel/control_v11p_sd15_normalbae",
"lllyasviel/control_v11p_sd15_scribble",
"lllyasviel/control_v11p_sd15_mlsd",
"lllyasviel/control_v11p_sd15_softedge",
"lllyasviel/control_v11p_sd15s2_lineart_anime",
"lllyasviel/control_v11p_sd15_lineart",
"lllyasviel/control_v11p_sd15_inpaint",
# "lllyasviel/control_v11u_sd15_tile",
# problem (temporary?) with huffingface "lllyasviel/control_v11u_sd15_tile",
# so for now replace "lllyasviel/control_v11f1e_sd15_tile",
"lllyasviel/control_v11e_sd15_shuffle",
"lllyasviel/control_v11e_sd15_ip2p",
"lllyasviel/control_v11f1e_sd15_tile",
#################################################
# thibaud sd v2.1 models (ControlNet v1.0? or v1.1?
##################################################
"thibaud/controlnet-sd21-openpose-diffusers",
"thibaud/controlnet-sd21-canny-diffusers",
"thibaud/controlnet-sd21-depth-diffusers",
"thibaud/controlnet-sd21-scribble-diffusers",
"thibaud/controlnet-sd21-hed-diffusers",
"thibaud/controlnet-sd21-zoedepth-diffusers",
"thibaud/controlnet-sd21-color-diffusers",
"thibaud/controlnet-sd21-openposev2-diffusers",
"thibaud/controlnet-sd21-lineart-diffusers",
"thibaud/controlnet-sd21-normalbae-diffusers",
"thibaud/controlnet-sd21-ade20k-diffusers",
##############################################
# ControlNetMediaPipeface, ControlNet v1.1
##############################################
# ["CrucibleAI/ControlNetMediaPipeFace", "diffusion_sd15"], # SD 1.5
# diffusion_sd15 needs to be passed to from_pretrained() as subfolder arg
# hacked t2l to split to model & subfolder if format is "model,subfolder"
"CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15", # SD 1.5
"CrucibleAI/ControlNetMediaPipeFace", # SD 2.1?
]
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)] CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]
CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]
CONTROLNET_RESIZE_VALUES = Literal[ CONTROLNET_RESIZE_VALUES = Literal[
tuple(
[
"just_resize", "just_resize",
"crop_resize", "crop_resize",
"fill_resize", "fill_resize",
"just_resize_simple", "just_resize_simple",
]
)
] ]
@ -110,9 +62,8 @@ class ControlNetModelField(BaseModel):
class ControlField(BaseModel): class ControlField(BaseModel):
image: ImageField = Field(default=None, description="The control image") image: ImageField = Field(description="The control image")
control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use") control_model: ControlNetModelField = Field(description="The ControlNet model to use")
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet") control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field( begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)" default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
@ -135,60 +86,39 @@ class ControlField(BaseModel):
raise ValueError("Control weights must be within -1 to 2 range") raise ValueError("Control weights must be within -1 to 2 range")
return v return v
class Config:
schema_extra = {
"required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"],
"ui": {
"type_hints": {
"control_weight": "float",
"control_model": "controlnet_model",
# "control_weight": "number",
}
},
}
class ControlOutput(BaseInvocationOutput): class ControlOutput(BaseInvocationOutput):
"""node output for ControlNet info""" """node output for ControlNet info"""
# fmt: off
type: Literal["control_output"] = "control_output" type: Literal["control_output"] = "control_output"
control: ControlField = Field(default=None, description="The control info")
# fmt: on # Outputs
control: ControlField = OutputField(description=FieldDescriptions.control)
@title("ControlNet")
@tags("controlnet")
class ControlNetInvocation(BaseInvocation): class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes""" """Collects ControlNet info to pass to other nodes"""
# fmt: off
type: Literal["controlnet"] = "controlnet" type: Literal["controlnet"] = "controlnet"
# Inputs
image: ImageField = Field(default=None, description="The control image")
control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny",
description="control model used")
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
begin_step_percent: float = Field(default=0, ge=-1, le=2,
description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)")
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used")
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode used")
# fmt: on
class Config(InvocationConfig): # Inputs
schema_extra = { image: ImageField = InputField(description="The control image")
"ui": { control_model: ControlNetModelField = InputField(
"title": "ControlNet", default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
"tags": ["controlnet", "latents"], )
"type_hints": { control_weight: Union[float, List[float]] = InputField(
"model": "model", default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
"control": "control", )
# "cfg_scale": "float", begin_step_percent: float = InputField(
"cfg_scale": "number", default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"
"control_weight": "float", )
}, end_step_percent: float = InputField(
}, default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
} )
control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
def invoke(self, context: InvocationContext) -> ControlOutput: def invoke(self, context: InvocationContext) -> ControlOutput:
return ControlOutput( return ControlOutput(
@ -204,19 +134,13 @@ class ControlNetInvocation(BaseInvocation):
) )
class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig): class ImageProcessorInvocation(BaseInvocation):
"""Base class for invocations that preprocess images for ControlNet""" """Base class for invocations that preprocess images for ControlNet"""
# fmt: off
type: Literal["image_processor"] = "image_processor" type: Literal["image_processor"] = "image_processor"
# Inputs
image: ImageField = Field(default=None, description="The image to process")
# fmt: on
class Config(InvocationConfig): # Inputs
schema_extra = { image: ImageField = InputField(description="The image to process")
"ui": {"title": "Image Processor", "tags": ["image", "processor"]},
}
def run_processor(self, image): def run_processor(self, image):
# superclass just passes through image without processing # superclass just passes through image without processing
@ -255,20 +179,20 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
) )
class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): @title("Canny Processor")
@tags("controlnet", "canny")
class CannyImageProcessorInvocation(ImageProcessorInvocation):
"""Canny edge detection for ControlNet""" """Canny edge detection for ControlNet"""
# fmt: off
type: Literal["canny_image_processor"] = "canny_image_processor" type: Literal["canny_image_processor"] = "canny_image_processor"
# Input
low_threshold: int = Field(default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)")
high_threshold: int = Field(default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)")
# fmt: on
class Config(InvocationConfig): # Input
schema_extra = { low_threshold: int = InputField(
"ui": {"title": "Canny Processor", "tags": ["controlnet", "canny", "image", "processor"]}, default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
} )
high_threshold: int = InputField(
default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)"
)
def run_processor(self, image): def run_processor(self, image):
canny_processor = CannyDetector() canny_processor = CannyDetector()
@ -276,23 +200,19 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
return processed_image return processed_image
class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): @title("HED (softedge) Processor")
@tags("controlnet", "hed", "softedge")
class HedImageProcessorInvocation(ImageProcessorInvocation):
"""Applies HED edge detection to image""" """Applies HED edge detection to image"""
# fmt: off
type: Literal["hed_image_processor"] = "hed_image_processor" type: Literal["hed_image_processor"] = "hed_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
# safe not supported in controlnet_aux v0.0.3
# safe: bool = Field(default=False, description="whether to use safe mode")
scribble: bool = Field(default=False, description="Whether to use scribble mode")
# fmt: on
class Config(InvocationConfig): # Inputs
schema_extra = { detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
"ui": {"title": "Softedge(HED) Processor", "tags": ["controlnet", "softedge", "hed", "image", "processor"]}, image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
} # safe not supported in controlnet_aux v0.0.3
# safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
def run_processor(self, image): def run_processor(self, image):
hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators") hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
@ -307,21 +227,17 @@ class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig)
return processed_image return processed_image
class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): @title("Lineart Processor")
@tags("controlnet", "lineart")
class LineartImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art processing to image""" """Applies line art processing to image"""
# fmt: off
type: Literal["lineart_image_processor"] = "lineart_image_processor" type: Literal["lineart_image_processor"] = "lineart_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
coarse: bool = Field(default=False, description="Whether to use coarse mode")
# fmt: on
class Config(InvocationConfig): # Inputs
schema_extra = { detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
"ui": {"title": "Lineart Processor", "tags": ["controlnet", "lineart", "image", "processor"]}, image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
} coarse: bool = InputField(default=False, description="Whether to use coarse mode")
def run_processor(self, image): def run_processor(self, image):
lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators") lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
@ -331,23 +247,16 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCon
return processed_image return processed_image
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): @title("Lineart Anime Processor")
@tags("controlnet", "lineart", "anime")
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art anime processing to image""" """Applies line art anime processing to image"""
# fmt: off
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor" type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
# fmt: on
class Config(InvocationConfig): # Inputs
schema_extra = { detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
"ui": { image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
"title": "Lineart Anime Processor",
"tags": ["controlnet", "lineart", "anime", "image", "processor"],
},
}
def run_processor(self, image): def run_processor(self, image):
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators") processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
@ -359,21 +268,17 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocati
return processed_image return processed_image
class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): @title("Openpose Processor")
@tags("controlnet", "openpose", "pose")
class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Openpose processing to image""" """Applies Openpose processing to image"""
# fmt: off
type: Literal["openpose_image_processor"] = "openpose_image_processor" type: Literal["openpose_image_processor"] = "openpose_image_processor"
# Inputs
hand_and_face: bool = Field(default=False, description="Whether to use hands and face mode")
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
# fmt: on
class Config(InvocationConfig): # Inputs
schema_extra = { hand_and_face: bool = InputField(default=False, description="Whether to use hands and face mode")
"ui": {"title": "Openpose Processor", "tags": ["controlnet", "openpose", "image", "processor"]}, detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
} image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image): def run_processor(self, image):
openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators") openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
@ -386,22 +291,18 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
return processed_image return processed_image
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): @title("Midas (Depth) Processor")
@tags("controlnet", "midas", "depth")
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Midas depth processing to image""" """Applies Midas depth processing to image"""
# fmt: off
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor" type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
# Inputs
a_mult: float = Field(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
bg_th: float = Field(default=0.1, ge=0, description="Midas parameter `bg_th`")
# depth_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode")
# fmt: on
class Config(InvocationConfig): # Inputs
schema_extra = { a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
"ui": {"title": "Midas (Depth) Processor", "tags": ["controlnet", "midas", "depth", "image", "processor"]}, bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`")
} # depth_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
def run_processor(self, image): def run_processor(self, image):
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators") midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
@ -415,20 +316,16 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocation
return processed_image return processed_image
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): @title("Normal BAE Processor")
@tags("controlnet", "normal", "bae")
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies NormalBae processing to image""" """Applies NormalBae processing to image"""
# fmt: off
type: Literal["normalbae_image_processor"] = "normalbae_image_processor" type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
# fmt: on
class Config(InvocationConfig): # Inputs
schema_extra = { detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
"ui": {"title": "Normal BAE Processor", "tags": ["controlnet", "normal", "bae", "image", "processor"]}, image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
}
def run_processor(self, image): def run_processor(self, image):
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators") normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
@ -438,22 +335,18 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationC
return processed_image return processed_image
class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): @title("MLSD Processor")
@tags("controlnet", "mlsd")
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
"""Applies MLSD processing to image""" """Applies MLSD processing to image"""
# fmt: off
type: Literal["mlsd_image_processor"] = "mlsd_image_processor" type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_v`")
thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_d`")
# fmt: on
class Config(InvocationConfig): # Inputs
schema_extra = { detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
"ui": {"title": "MLSD Processor", "tags": ["controlnet", "mlsd", "image", "processor"]}, image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
} thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
def run_processor(self, image): def run_processor(self, image):
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators") mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
@ -467,22 +360,18 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
return processed_image return processed_image
class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): @title("PIDI Processor")
@tags("controlnet", "pidi")
class PidiImageProcessorInvocation(ImageProcessorInvocation):
"""Applies PIDI processing to image""" """Applies PIDI processing to image"""
# fmt: off
type: Literal["pidi_image_processor"] = "pidi_image_processor" type: Literal["pidi_image_processor"] = "pidi_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
safe: bool = Field(default=False, description="Whether to use safe mode")
scribble: bool = Field(default=False, description="Whether to use scribble mode")
# fmt: on
class Config(InvocationConfig): # Inputs
schema_extra = { detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
"ui": {"title": "PIDI Processor", "tags": ["controlnet", "pidi", "image", "processor"]}, image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
} safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
def run_processor(self, image): def run_processor(self, image):
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators") pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
@ -496,26 +385,19 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
return processed_image return processed_image
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): @title("Content Shuffle Processor")
@tags("controlnet", "contentshuffle")
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
"""Applies content shuffle processing to image""" """Applies content shuffle processing to image"""
# fmt: off
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor" type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
h: Optional[int] = Field(default=512, ge=0, description="Content shuffle `h` parameter")
w: Optional[int] = Field(default=512, ge=0, description="Content shuffle `w` parameter")
f: Optional[int] = Field(default=256, ge=0, description="Content shuffle `f` parameter")
# fmt: on
class Config(InvocationConfig): # Inputs
schema_extra = { detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
"ui": { image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
"title": "Content Shuffle Processor", h: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
"tags": ["controlnet", "contentshuffle", "image", "processor"], w: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
}, f: Optional[int] = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
}
def run_processor(self, image): def run_processor(self, image):
content_shuffle_processor = ContentShuffleDetector() content_shuffle_processor = ContentShuffleDetector()
@ -531,17 +413,12 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvoca
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13 # should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): @title("Zoe (Depth) Processor")
@tags("controlnet", "zoe", "depth")
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Zoe depth processing to image""" """Applies Zoe depth processing to image"""
# fmt: off
type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor" type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Zoe (Depth) Processor", "tags": ["controlnet", "zoe", "depth", "image", "processor"]},
}
def run_processor(self, image): def run_processor(self, image):
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators") zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
@ -549,20 +426,16 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
return processed_image return processed_image
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): @title("Mediapipe Face Processor")
@tags("controlnet", "mediapipe", "face")
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
"""Applies mediapipe face processing to image""" """Applies mediapipe face processing to image"""
# fmt: off
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor" type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
# Inputs
max_faces: int = Field(default=1, ge=1, description="Maximum number of faces to detect")
min_confidence: float = Field(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
# fmt: on
class Config(InvocationConfig): # Inputs
schema_extra = { max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
"ui": {"title": "Mediapipe Processor", "tags": ["controlnet", "mediapipe", "image", "processor"]}, min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
}
def run_processor(self, image): def run_processor(self, image):
# MediaPipeFaceDetector throws an error if image has alpha channel # MediaPipeFaceDetector throws an error if image has alpha channel
@ -574,23 +447,19 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
return processed_image return processed_image
class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): @title("Leres (Depth) Processor")
@tags("controlnet", "leres", "depth")
class LeresImageProcessorInvocation(ImageProcessorInvocation):
"""Applies leres processing to image""" """Applies leres processing to image"""
# fmt: off
type: Literal["leres_image_processor"] = "leres_image_processor" type: Literal["leres_image_processor"] = "leres_image_processor"
# Inputs
thr_a: float = Field(default=0, description="Leres parameter `thr_a`")
thr_b: float = Field(default=0, description="Leres parameter `thr_b`")
boost: bool = Field(default=False, description="Whether to use boost mode")
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
# fmt: on
class Config(InvocationConfig): # Inputs
schema_extra = { thr_a: float = InputField(default=0, description="Leres parameter `thr_a`")
"ui": {"title": "Leres (Depth) Processor", "tags": ["controlnet", "leres", "depth", "image", "processor"]}, thr_b: float = InputField(default=0, description="Leres parameter `thr_b`")
} boost: bool = InputField(default=False, description="Whether to use boost mode")
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image): def run_processor(self, image):
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators") leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
@ -605,21 +474,16 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
return processed_image return processed_image
class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): @title("Tile Resample Processor")
# fmt: off @tags("controlnet", "tile")
type: Literal["tile_image_processor"] = "tile_image_processor" class TileResamplerProcessorInvocation(ImageProcessorInvocation):
# Inputs """Tile resampler processor"""
#res: int = Field(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
down_sampling_rate: float = Field(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
# fmt: on
class Config(InvocationConfig): type: Literal["tile_image_processor"] = "tile_image_processor"
schema_extra = {
"ui": { # Inputs
"title": "Tile Resample Processor", # res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
"tags": ["controlnet", "tile", "resample", "image", "processor"], down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
},
}
# tile_resample copied from sd-webui-controlnet/scripts/processor.py # tile_resample copied from sd-webui-controlnet/scripts/processor.py
def tile_resample( def tile_resample(
@ -648,20 +512,12 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
return processed_image return processed_image
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): @title("Segment Anything Processor")
@tags("controlnet", "segmentanything")
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
"""Applies segment anything processing to image""" """Applies segment anything processing to image"""
# fmt: off
type: Literal["segment_anything_processor"] = "segment_anything_processor" type: Literal["segment_anything_processor"] = "segment_anything_processor"
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Segment Anything Processor",
"tags": ["controlnet", "segment", "anything", "sam", "image", "processor"],
},
}
def run_processor(self, image): def run_processor(self, image):
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints") # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")

View File

@ -5,40 +5,22 @@ from typing import Literal
import cv2 as cv import cv2 as cv
import numpy import numpy
from PIL import Image, ImageOps from PIL import Image, ImageOps
from pydantic import BaseModel, Field from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin from invokeai.app.models.image import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
from .image import ImageOutput
class CvInvocationConfig(BaseModel): @title("OpenCV Inpaint")
"""Helper class to provide all OpenCV invocations with additional config""" @tags("opencv", "inpaint")
class CvInpaintInvocation(BaseInvocation):
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["cv", "image"],
},
}
class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
"""Simple inpaint using opencv.""" """Simple inpaint using opencv."""
# fmt: off
type: Literal["cv_inpaint"] = "cv_inpaint" type: Literal["cv_inpaint"] = "cv_inpaint"
# Inputs # Inputs
image: ImageField = Field(default=None, description="The image to inpaint") image: ImageField = InputField(description="The image to inpaint")
mask: ImageField = Field(default=None, description="The mask to use when inpainting") mask: ImageField = InputField(description="The mask to use when inpainting")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "OpenCV Inpaint", "tags": ["opencv", "inpaint"]},
}
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)

View File

@ -1,60 +1,31 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from pathlib import Path from pathlib import Path
from typing import Literal, Optional, Union from typing import Literal, Optional
import cv2 import cv2
import numpy import numpy
from PIL import Image, ImageChops, ImageFilter, ImageOps from PIL import Image, ImageChops, ImageFilter, ImageOps
from pydantic import Field
from invokeai.app.invocations.metadata import CoreMetadata from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker from invokeai.backend.image_util.safety_checker import SafetyChecker
from ..models.image import ImageCategory, ImageField, ImageOutput, MaskOutput, PILInvocationConfig, ResourceOrigin from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, tags, title
class LoadImageInvocation(BaseInvocation):
"""Load an image and provide it as output."""
# fmt: off
type: Literal["load_image"] = "load_image"
# Inputs
image: Optional[ImageField] = Field(
default=None, description="The image to load"
)
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Load Image", "tags": ["image", "load"]},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
return ImageOutput(
image=ImageField(image_name=self.image.image_name),
width=image.width,
height=image.height,
)
@title("Show Image")
@tags("image")
class ShowImageInvocation(BaseInvocation): class ShowImageInvocation(BaseInvocation):
"""Displays a provided image, and passes it forward in the pipeline.""" """Displays a provided image, and passes it forward in the pipeline."""
# Metadata
type: Literal["show_image"] = "show_image" type: Literal["show_image"] = "show_image"
# Inputs # Inputs
image: Optional[ImageField] = Field(default=None, description="The image to show") image: ImageField = InputField(description="The image to show")
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Show Image", "tags": ["image", "show"]},
}
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
@ -70,24 +41,20 @@ class ShowImageInvocation(BaseInvocation):
) )
class ImageCropInvocation(BaseInvocation, PILInvocationConfig): @title("Crop Image")
@tags("image", "crop")
class ImageCropInvocation(BaseInvocation):
"""Crops an image to a specified box. The box can be outside of the image.""" """Crops an image to a specified box. The box can be outside of the image."""
# fmt: off # Metadata
type: Literal["img_crop"] = "img_crop" type: Literal["img_crop"] = "img_crop"
# Inputs # Inputs
image: Optional[ImageField] = Field(default=None, description="The image to crop") image: ImageField = InputField(description="The image to crop")
x: int = Field(default=0, description="The left x coordinate of the crop rectangle") x: int = InputField(default=0, description="The left x coordinate of the crop rectangle")
y: int = Field(default=0, description="The top y coordinate of the crop rectangle") y: int = InputField(default=0, description="The top y coordinate of the crop rectangle")
width: int = Field(default=512, gt=0, description="The width of the crop rectangle") width: int = InputField(default=512, gt=0, description="The width of the crop rectangle")
height: int = Field(default=512, gt=0, description="The height of the crop rectangle") height: int = InputField(default=512, gt=0, description="The height of the crop rectangle")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Crop Image", "tags": ["image", "crop"]},
}
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
@ -111,24 +78,23 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
) )
class ImagePasteInvocation(BaseInvocation, PILInvocationConfig): @title("Paste Image")
@tags("image", "paste")
class ImagePasteInvocation(BaseInvocation):
"""Pastes an image into another image.""" """Pastes an image into another image."""
# fmt: off # Metadata
type: Literal["img_paste"] = "img_paste" type: Literal["img_paste"] = "img_paste"
# Inputs # Inputs
base_image: Optional[ImageField] = Field(default=None, description="The base image") base_image: ImageField = InputField(description="The base image")
image: Optional[ImageField] = Field(default=None, description="The image to paste") image: ImageField = InputField(description="The image to paste")
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting") mask: Optional[ImageField] = InputField(
x: int = Field(default=0, description="The left x coordinate at which to paste the image") default=None,
y: int = Field(default=0, description="The top y coordinate at which to paste the image") description="The mask to use when pasting",
# fmt: on )
x: int = InputField(default=0, description="The left x coordinate at which to paste the image")
class Config(InvocationConfig): y: int = InputField(default=0, description="The top y coordinate at which to paste the image")
schema_extra = {
"ui": {"title": "Paste Image", "tags": ["image", "paste"]},
}
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get_pil_image(self.base_image.image_name) base_image = context.services.images.get_pil_image(self.base_image.image_name)
@ -164,23 +130,19 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
) )
class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig): @title("Mask from Alpha")
@tags("image", "mask")
class MaskFromAlphaInvocation(BaseInvocation):
"""Extracts the alpha channel of an image as a mask.""" """Extracts the alpha channel of an image as a mask."""
# fmt: off # Metadata
type: Literal["tomask"] = "tomask" type: Literal["tomask"] = "tomask"
# Inputs # Inputs
image: Optional[ImageField] = Field(default=None, description="The image to create the mask from") image: ImageField = InputField(description="The image to create the mask from")
invert: bool = Field(default=False, description="Whether or not to invert the mask") invert: bool = InputField(default=False, description="Whether or not to invert the mask")
# fmt: on
class Config(InvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput:
schema_extra = {
"ui": {"title": "Mask From Alpha", "tags": ["image", "mask", "alpha"]},
}
def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
image_mask = image.split()[-1] image_mask = image.split()[-1]
@ -196,28 +158,24 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
) )
return MaskOutput( return ImageOutput(
mask=ImageField(image_name=image_dto.image_name), image=ImageField(image_name=image_dto.image_name),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig): @title("Multiply Images")
@tags("image", "multiply")
class ImageMultiplyInvocation(BaseInvocation):
"""Multiplies two images together using `PIL.ImageChops.multiply()`.""" """Multiplies two images together using `PIL.ImageChops.multiply()`."""
# fmt: off # Metadata
type: Literal["img_mul"] = "img_mul" type: Literal["img_mul"] = "img_mul"
# Inputs # Inputs
image1: Optional[ImageField] = Field(default=None, description="The first image to multiply") image1: ImageField = InputField(description="The first image to multiply")
image2: Optional[ImageField] = Field(default=None, description="The second image to multiply") image2: ImageField = InputField(description="The second image to multiply")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Multiply Images", "tags": ["image", "multiply"]},
}
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image1 = context.services.images.get_pil_image(self.image1.image_name) image1 = context.services.images.get_pil_image(self.image1.image_name)
@ -244,21 +202,17 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
IMAGE_CHANNELS = Literal["A", "R", "G", "B"] IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
class ImageChannelInvocation(BaseInvocation, PILInvocationConfig): @title("Extract Image Channel")
@tags("image", "channel")
class ImageChannelInvocation(BaseInvocation):
"""Gets a channel from an image.""" """Gets a channel from an image."""
# fmt: off # Metadata
type: Literal["img_chan"] = "img_chan" type: Literal["img_chan"] = "img_chan"
# Inputs # Inputs
image: Optional[ImageField] = Field(default=None, description="The image to get the channel from") image: ImageField = InputField(description="The image to get the channel from")
channel: IMAGE_CHANNELS = Field(default="A", description="The channel to get") channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Image Channel", "tags": ["image", "channel"]},
}
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
@ -284,21 +238,17 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"] IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
class ImageConvertInvocation(BaseInvocation, PILInvocationConfig): @title("Convert Image Mode")
@tags("image", "convert")
class ImageConvertInvocation(BaseInvocation):
"""Converts an image to a different mode.""" """Converts an image to a different mode."""
# fmt: off # Metadata
type: Literal["img_conv"] = "img_conv" type: Literal["img_conv"] = "img_conv"
# Inputs # Inputs
image: Optional[ImageField] = Field(default=None, description="The image to convert") image: ImageField = InputField(description="The image to convert")
mode: IMAGE_MODES = Field(default="L", description="The mode to convert to") mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Convert Image", "tags": ["image", "convert"]},
}
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
@ -321,22 +271,19 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
) )
class ImageBlurInvocation(BaseInvocation, PILInvocationConfig): @title("Blur Image")
@tags("image", "blur")
class ImageBlurInvocation(BaseInvocation):
"""Blurs an image""" """Blurs an image"""
# fmt: off # Metadata
type: Literal["img_blur"] = "img_blur" type: Literal["img_blur"] = "img_blur"
# Inputs # Inputs
image: Optional[ImageField] = Field(default=None, description="The image to blur") image: ImageField = InputField(description="The image to blur")
radius: float = Field(default=8.0, ge=0, description="The blur radius") radius: float = InputField(default=8.0, ge=0, description="The blur radius")
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur") # Metadata
# fmt: on blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur")
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Blur Image", "tags": ["image", "blur"]},
}
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
@ -382,23 +329,19 @@ PIL_RESAMPLING_MAP = {
} }
class ImageResizeInvocation(BaseInvocation, PILInvocationConfig): @title("Resize Image")
@tags("image", "resize")
class ImageResizeInvocation(BaseInvocation):
"""Resizes an image to specific dimensions""" """Resizes an image to specific dimensions"""
# fmt: off # Metadata
type: Literal["img_resize"] = "img_resize" type: Literal["img_resize"] = "img_resize"
# Inputs # Inputs
image: Optional[ImageField] = Field(default=None, description="The image to resize") image: ImageField = InputField(description="The image to resize")
width: Union[int, None] = Field(ge=64, multiple_of=8, description="The width to resize to (px)") width: int = InputField(default=512, ge=64, multiple_of=8, description="The width to resize to (px)")
height: Union[int, None] = Field(ge=64, multiple_of=8, description="The height to resize to (px)") height: int = InputField(default=512, ge=64, multiple_of=8, description="The height to resize to (px)")
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode") resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Resize Image", "tags": ["image", "resize"]},
}
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
@ -426,22 +369,22 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
) )
class ImageScaleInvocation(BaseInvocation, PILInvocationConfig): @title("Scale Image")
@tags("image", "scale")
class ImageScaleInvocation(BaseInvocation):
"""Scales an image by a factor""" """Scales an image by a factor"""
# fmt: off # Metadata
type: Literal["img_scale"] = "img_scale" type: Literal["img_scale"] = "img_scale"
# Inputs # Inputs
image: Optional[ImageField] = Field(default=None, description="The image to scale") image: ImageField = InputField(description="The image to scale")
scale_factor: Optional[float] = Field(default=2.0, gt=0, description="The factor by which to scale the image") scale_factor: float = InputField(
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode") default=2.0,
# fmt: on gt=0,
description="The factor by which to scale the image",
class Config(InvocationConfig): )
schema_extra = { resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
"ui": {"title": "Scale Image", "tags": ["image", "scale"]},
}
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
@ -471,22 +414,18 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
) )
class ImageLerpInvocation(BaseInvocation, PILInvocationConfig): @title("Lerp Image")
@tags("image", "lerp")
class ImageLerpInvocation(BaseInvocation):
"""Linear interpolation of all pixels of an image""" """Linear interpolation of all pixels of an image"""
# fmt: off # Metadata
type: Literal["img_lerp"] = "img_lerp" type: Literal["img_lerp"] = "img_lerp"
# Inputs # Inputs
image: Optional[ImageField] = Field(default=None, description="The image to lerp") image: ImageField = InputField(description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum output value") min: int = InputField(default=0, ge=0, le=255, description="The minimum output value")
max: int = Field(default=255, ge=0, le=255, description="The maximum output value") max: int = InputField(default=255, ge=0, le=255, description="The maximum output value")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Image Linear Interpolation", "tags": ["image", "linear", "interpolation", "lerp"]},
}
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
@ -512,25 +451,18 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
) )
class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig): @title("Inverse Lerp Image")
@tags("image", "ilerp")
class ImageInverseLerpInvocation(BaseInvocation):
"""Inverse linear interpolation of all pixels of an image""" """Inverse linear interpolation of all pixels of an image"""
# fmt: off # Metadata
type: Literal["img_ilerp"] = "img_ilerp" type: Literal["img_ilerp"] = "img_ilerp"
# Inputs # Inputs
image: Optional[ImageField] = Field(default=None, description="The image to lerp") image: ImageField = InputField(description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum input value") min: int = InputField(default=0, ge=0, le=255, description="The minimum input value")
max: int = Field(default=255, ge=0, le=255, description="The maximum input value") max: int = InputField(default=255, ge=0, le=255, description="The maximum input value")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Image Inverse Linear Interpolation",
"tags": ["image", "linear", "interpolation", "inverse"],
},
}
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
@ -556,21 +488,19 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
) )
class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig): @title("Blur NSFW Image")
@tags("image", "nsfw")
class ImageNSFWBlurInvocation(BaseInvocation):
"""Add blur to NSFW-flagged images""" """Add blur to NSFW-flagged images"""
# fmt: off # Metadata
type: Literal["img_nsfw"] = "img_nsfw" type: Literal["img_nsfw"] = "img_nsfw"
# Inputs # Inputs
image: Optional[ImageField] = Field(default=None, description="The image to check") image: ImageField = InputField(description="The image to check")
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image") metadata: Optional[CoreMetadata] = InputField(
# fmt: on default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
)
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Blur NSFW Images", "tags": ["image", "nsfw", "checker"]},
}
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
@ -607,22 +537,20 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
return caution.resize((caution.width // 2, caution.height // 2)) return caution.resize((caution.width // 2, caution.height // 2))
class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig): @title("Add Invisible Watermark")
@tags("image", "watermark")
class ImageWatermarkInvocation(BaseInvocation):
"""Add an invisible watermark to an image""" """Add an invisible watermark to an image"""
# fmt: off # Metadata
type: Literal["img_watermark"] = "img_watermark" type: Literal["img_watermark"] = "img_watermark"
# Inputs # Inputs
image: Optional[ImageField] = Field(default=None, description="The image to check") image: ImageField = InputField(description="The image to check")
text: str = Field(default='InvokeAI', description="Watermark text") text: str = InputField(default="InvokeAI", description="Watermark text")
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image") metadata: Optional[CoreMetadata] = InputField(
# fmt: on default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
)
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Add Invisible Watermark", "tags": ["image", "watermark", "invisible"]},
}
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
@ -644,21 +572,23 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
) )
class MaskEdgeInvocation(BaseInvocation, PILInvocationConfig): @title("Mask Edge")
@tags("image", "mask", "inpaint")
class MaskEdgeInvocation(BaseInvocation):
"""Applies an edge mask to an image""" """Applies an edge mask to an image"""
# fmt: off
type: Literal["mask_edge"] = "mask_edge" type: Literal["mask_edge"] = "mask_edge"
# Inputs # Inputs
image: Optional[ImageField] = Field(default=None, description="The image to apply the mask to") image: ImageField = InputField(description="The image to apply the mask to")
edge_size: int = Field(description="The size of the edge") edge_size: int = InputField(description="The size of the edge")
edge_blur: int = Field(description="The amount of blur on the edge") edge_blur: int = InputField(description="The amount of blur on the edge")
low_threshold: int = Field(description="First threshold for the hysteresis procedure in Canny edge detection") low_threshold: int = InputField(description="First threshold for the hysteresis procedure in Canny edge detection")
high_threshold: int = Field(description="Second threshold for the hysteresis procedure in Canny edge detection") high_threshold: int = InputField(
# fmt: on description="Second threshold for the hysteresis procedure in Canny edge detection"
)
def invoke(self, context: InvocationContext) -> MaskOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
mask = context.services.images.get_pil_image(self.image.image_name) mask = context.services.images.get_pil_image(self.image.image_name)
npimg = numpy.asarray(mask, dtype=numpy.uint8) npimg = numpy.asarray(mask, dtype=numpy.uint8)
@ -683,28 +613,23 @@ class MaskEdgeInvocation(BaseInvocation, PILInvocationConfig):
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
) )
return MaskOutput( return ImageOutput(
mask=ImageField(image_name=image_dto.image_name), image=ImageField(image_name=image_dto.image_name),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
class MaskCombineInvocation(BaseInvocation, PILInvocationConfig): @title("Combine Mask")
@tags("image", "mask", "multiply")
class MaskCombineInvocation(BaseInvocation):
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`.""" """Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
# fmt: off
type: Literal["mask_combine"] = "mask_combine" type: Literal["mask_combine"] = "mask_combine"
# Inputs # Inputs
mask1: ImageField = Field(default=None, description="The first mask to combine") mask1: ImageField = InputField(description="The first mask to combine")
mask2: ImageField = Field(default=None, description="The second image to combine") mask2: ImageField = InputField(description="The second image to combine")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Mask Combine", "tags": ["mask", "combine"]},
}
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
mask1 = context.services.images.get_pil_image(self.mask1.image_name).convert("L") mask1 = context.services.images.get_pil_image(self.mask1.image_name).convert("L")
@ -728,7 +653,9 @@ class MaskCombineInvocation(BaseInvocation, PILInvocationConfig):
) )
class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig): @title("Color Correct")
@tags("image", "color")
class ColorCorrectInvocation(BaseInvocation):
""" """
Shifts the colors of a target image to match the reference image, optionally Shifts the colors of a target image to match the reference image, optionally
using a mask to only color-correct certain regions of the target image. using a mask to only color-correct certain regions of the target image.
@ -736,10 +663,11 @@ class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["color_correct"] = "color_correct" type: Literal["color_correct"] = "color_correct"
image: Optional[ImageField] = Field(default=None, description="The image to color-correct") # Inputs
reference: Optional[ImageField] = Field(default=None, description="Reference image for color-correction") image: ImageField = InputField(description="The image to color-correct")
mask: Optional[ImageField] = Field(default=None, description="Mask to use when applying color-correction") reference: ImageField = InputField(description="Reference image for color-correction")
mask_blur_radius: float = Field(default=8, description="Mask blur radius") mask: Optional[ImageField] = InputField(default=None, description="Mask to use when applying color-correction")
mask_blur_radius: float = InputField(default=8, description="Mask blur radius")
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
pil_init_mask = None pil_init_mask = None
@ -833,16 +761,16 @@ class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
) )
@title("Image Hue Adjustment")
@tags("image", "hue", "hsl")
class ImageHueAdjustmentInvocation(BaseInvocation): class ImageHueAdjustmentInvocation(BaseInvocation):
"""Adjusts the Hue of an image.""" """Adjusts the Hue of an image."""
# fmt: off
type: Literal["img_hue_adjust"] = "img_hue_adjust" type: Literal["img_hue_adjust"] = "img_hue_adjust"
# Inputs # Inputs
image: ImageField = Field(default=None, description="The image to adjust") image: ImageField = InputField(description="The image to adjust")
hue: int = Field(default=0, description="The degrees by which to rotate the hue, 0-360") hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
pil_image = context.services.images.get_pil_image(self.image.image_name) pil_image = context.services.images.get_pil_image(self.image.image_name)
@ -877,16 +805,18 @@ class ImageHueAdjustmentInvocation(BaseInvocation):
) )
@title("Image Luminosity Adjustment")
@tags("image", "luminosity", "hsl")
class ImageLuminosityAdjustmentInvocation(BaseInvocation): class ImageLuminosityAdjustmentInvocation(BaseInvocation):
"""Adjusts the Luminosity (Value) of an image.""" """Adjusts the Luminosity (Value) of an image."""
# fmt: off
type: Literal["img_luminosity_adjust"] = "img_luminosity_adjust" type: Literal["img_luminosity_adjust"] = "img_luminosity_adjust"
# Inputs # Inputs
image: ImageField = Field(default=None, description="The image to adjust") image: ImageField = InputField(description="The image to adjust")
luminosity: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)") luminosity: float = InputField(
# fmt: on default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)"
)
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
pil_image = context.services.images.get_pil_image(self.image.image_name) pil_image = context.services.images.get_pil_image(self.image.image_name)
@ -925,16 +855,16 @@ class ImageLuminosityAdjustmentInvocation(BaseInvocation):
) )
@title("Image Saturation Adjustment")
@tags("image", "saturation", "hsl")
class ImageSaturationAdjustmentInvocation(BaseInvocation): class ImageSaturationAdjustmentInvocation(BaseInvocation):
"""Adjusts the Saturation of an image.""" """Adjusts the Saturation of an image."""
# fmt: off
type: Literal["img_saturation_adjust"] = "img_saturation_adjust" type: Literal["img_saturation_adjust"] = "img_saturation_adjust"
# Inputs # Inputs
image: ImageField = Field(default=None, description="The image to adjust") image: ImageField = InputField(description="The image to adjust")
saturation: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation") saturation: float = InputField(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
pil_image = context.services.images.get_pil_image(self.image.image_name) pil_image = context.services.images.get_pil_image(self.image.image_name)

View File

@ -5,18 +5,13 @@ from typing import Literal, Optional, get_args
import numpy as np import numpy as np
import math import math
from PIL import Image, ImageOps from PIL import Image, ImageOps
from pydantic import Field from invokeai.app.invocations.primitives import ImageField, ImageOutput, ColorField
from invokeai.app.invocations.image import ImageOutput
from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.image_util.patchmatch import PatchMatch from invokeai.backend.image_util.patchmatch import PatchMatch
from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import ( from .baseinvocation import BaseInvocation, InputField, InvocationContext, title, tags
BaseInvocation,
InvocationConfig,
InvocationContext,
)
def infill_methods() -> list[str]: def infill_methods() -> list[str]:
@ -114,21 +109,20 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
return si return si
@title("Solid Color Infill")
@tags("image", "inpaint")
class InfillColorInvocation(BaseInvocation): class InfillColorInvocation(BaseInvocation):
"""Infills transparent areas of an image with a solid color""" """Infills transparent areas of an image with a solid color"""
type: Literal["infill_rgba"] = "infill_rgba" type: Literal["infill_rgba"] = "infill_rgba"
image: Optional[ImageField] = Field(default=None, description="The image to infill")
color: ColorField = Field( # Inputs
image: ImageField = InputField(description="The image to infill")
color: ColorField = InputField(
default=ColorField(r=127, g=127, b=127, a=255), default=ColorField(r=127, g=127, b=127, a=255),
description="The color to use to infill", description="The color to use to infill",
) )
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Color Infill", "tags": ["image", "inpaint", "color", "infill"]},
}
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
@ -153,25 +147,23 @@ class InfillColorInvocation(BaseInvocation):
) )
@title("Tile Infill")
@tags("image", "inpaint")
class InfillTileInvocation(BaseInvocation): class InfillTileInvocation(BaseInvocation):
"""Infills transparent areas of an image with tiles of the image""" """Infills transparent areas of an image with tiles of the image"""
type: Literal["infill_tile"] = "infill_tile" type: Literal["infill_tile"] = "infill_tile"
image: Optional[ImageField] = Field(default=None, description="The image to infill") # Input
tile_size: int = Field(default=32, ge=1, description="The tile size (px)") image: ImageField = InputField(description="The image to infill")
seed: int = Field( tile_size: int = InputField(default=32, ge=1, description="The tile size (px)")
seed: int = InputField(
ge=0, ge=0,
le=SEED_MAX, le=SEED_MAX,
description="The seed to use for tile generation (omit for random)", description="The seed to use for tile generation (omit for random)",
default_factory=get_random_seed, default_factory=get_random_seed,
) )
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Tile Infill", "tags": ["image", "inpaint", "tile", "infill"]},
}
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
@ -194,17 +186,15 @@ class InfillTileInvocation(BaseInvocation):
) )
@title("PatchMatch Infill")
@tags("image", "inpaint")
class InfillPatchMatchInvocation(BaseInvocation): class InfillPatchMatchInvocation(BaseInvocation):
"""Infills transparent areas of an image using the PatchMatch algorithm""" """Infills transparent areas of an image using the PatchMatch algorithm"""
type: Literal["infill_patchmatch"] = "infill_patchmatch" type: Literal["infill_patchmatch"] = "infill_patchmatch"
image: Optional[ImageField] = Field(default=None, description="The image to infill") # Inputs
image: ImageField = InputField(description="The image to infill")
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Patch Match Infill", "tags": ["image", "inpaint", "patchmatch", "infill"]},
}
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)

View File

@ -13,16 +13,25 @@ from diffusers.models.attention_processor import (
LoRAXFormersAttnProcessor, LoRAXFormersAttnProcessor,
XFormersAttnProcessor, XFormersAttnProcessor,
) )
from diffusers.schedulers import DPMSolverSDEScheduler, SchedulerMixin as Scheduler from diffusers.schedulers import DPMSolverSDEScheduler
from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
from torchvision.transforms.functional import resize as tv_resize from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.metadata import CoreMetadata from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.invocations.primitives import (
ImageField,
ImageOutput,
LatentsField,
LatentsOutput,
build_latents_output,
)
from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models import ModelType, SilenceWarnings from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from ...backend.model_management import BaseModelType, ModelPatcher from ...backend.model_management import BaseModelType, ModelPatcher
from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import ( from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData, ConditioningData,
@ -32,48 +41,27 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
) )
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_precision, choose_torch_device, torch_dtype from ...backend.util.devices import choose_precision, choose_torch_device
from ..models.image import ImageCategory, ImageField, ResourceOrigin from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
OutputField,
UIType,
tags,
title,
)
from .compel import ConditioningField from .compel import ConditioningField
from .controlnet_image_processors import ControlField from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField from .model import ModelInfo, UNetField, VaeField
DEFAULT_PRECISION = choose_precision(choose_torch_device()) DEFAULT_PRECISION = choose_precision(choose_torch_device())
class LatentsField(BaseModel):
"""A latents field used for passing latents between invocations"""
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
seed: Optional[int] = Field(description="Seed used to generate this latents")
class Config:
schema_extra = {"required": ["latents_name"]}
class LatentsOutput(BaseInvocationOutput):
"""Base class for invocations that output latents"""
# fmt: off
type: Literal["latents_output"] = "latents_output"
# Inputs
latents: LatentsField = Field(default=None, description="The output latents")
width: int = Field(description="The width of the latents in pixels")
height: int = Field(description="The height of the latents in pixels")
# fmt: on
def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int]):
return LatentsOutput(
latents=LatentsField(latents_name=latents_name, seed=seed),
width=latents.size()[3] * 8,
height=latents.size()[2] * 8,
)
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))] SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
@ -111,30 +99,36 @@ def get_scheduler(
return scheduler return scheduler
@title("Denoise Latents")
@tags("latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l")
class DenoiseLatentsInvocation(BaseInvocation): class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images""" """Denoises noisy latents to decodable images"""
type: Literal["denoise_latents"] = "denoise_latents" type: Literal["denoise_latents"] = "denoise_latents"
# Inputs # Inputs
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation") positive_conditioning: ConditioningField = InputField(
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation") description=FieldDescriptions.positive_cond, input=Input.Connection
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: Union[float, List[float]] = Field(
default=7.5,
ge=1,
description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt",
) )
denoising_start: float = Field(default=0.0, ge=0, le=1, description="") negative_conditioning: ConditioningField = InputField(
denoising_end: float = Field(default=1.0, ge=0, le=1, description="") description=FieldDescriptions.negative_cond, input=Input.Connection
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use") )
unet: UNetField = Field(default=None, description="UNet submodel") noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection)
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use") steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
latents: Optional[LatentsField] = Field(description="The latents to use as a base image") cfg_scale: Union[float, List[float]] = InputField(
mask: Optional[ImageField] = Field( default=7.5, ge=1, description=FieldDescriptions.cfg_scale, ui_type=UIType.Float
None, )
description="Mask", denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
scheduler: SAMPLER_NAME_VALUES = InputField(default="euler", description=FieldDescriptions.scheduler)
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection)
control: Union[ControlField, list[ControlField]] = InputField(
default=None, description=FieldDescriptions.control, input=Input.Connection
)
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
mask: Optional[ImageField] = InputField(
default=None,
description=FieldDescriptions.mask,
) )
@validator("cfg_scale") @validator("cfg_scale")
@ -149,20 +143,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
raise ValueError("cfg_scale must be greater than 1") raise ValueError("cfg_scale must be greater than 1")
return v return v
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Denoise Latents",
"tags": ["denoise", "latents"],
"type_hints": {
"model": "model",
"control": "control",
"cfg_scale": "number",
},
},
}
# TODO: pass this an emitter method or something? or a session for dispatching? # TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress( def dispatch_progress(
self, self,
@ -474,29 +454,29 @@ class DenoiseLatentsInvocation(BaseInvocation):
return build_latents_output(latents_name=name, latents=result_latents, seed=seed) return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
# Latent to image @title("Latents to Image")
@tags("latents", "image", "vae")
class LatentsToImageInvocation(BaseInvocation): class LatentsToImageInvocation(BaseInvocation):
"""Generates an image from latents.""" """Generates an image from latents."""
type: Literal["l2i"] = "l2i" type: Literal["l2i"] = "l2i"
# Inputs # Inputs
latents: Optional[LatentsField] = Field(description="The latents to generate an image from") latents: LatentsField = InputField(
vae: VaeField = Field(default=None, description="Vae submodel") description=FieldDescriptions.latents,
tiled: bool = Field(default=False, description="Decode latents by overlaping tiles (less memory consumption)") input=Input.Connection,
fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision") )
metadata: Optional[CoreMetadata] = Field( vae: VaeField = InputField(
default=None, description="Optional core metadata to be written to the image" description=FieldDescriptions.vae,
input=Input.Connection,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
metadata: CoreMetadata = InputField(
default=None,
description=FieldDescriptions.core_metadata,
ui_hidden=True,
) )
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Latents To Image",
"tags": ["latents", "image"],
},
}
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
@ -574,24 +554,30 @@ class LatentsToImageInvocation(BaseInvocation):
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"] LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
@title("Resize Latents")
@tags("latents", "resize")
class ResizeLatentsInvocation(BaseInvocation): class ResizeLatentsInvocation(BaseInvocation):
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8.""" """Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
type: Literal["lresize"] = "lresize" type: Literal["lresize"] = "lresize"
# Inputs # Inputs
latents: Optional[LatentsField] = Field(description="The latents to resize") latents: LatentsField = InputField(
width: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The width to resize to (px)") description=FieldDescriptions.latents,
height: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The height to resize to (px)") input=Input.Connection,
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
antialias: bool = Field(
default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
) )
width: int = InputField(
class Config(InvocationConfig): ge=64,
schema_extra = { multiple_of=8,
"ui": {"title": "Resize Latents", "tags": ["latents", "resize"]}, description=FieldDescriptions.width,
} )
height: int = InputField(
ge=64,
multiple_of=8,
description=FieldDescriptions.width,
)
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name) latents = context.services.latents.get(self.latents.latents_name)
@ -616,23 +602,21 @@ class ResizeLatentsInvocation(BaseInvocation):
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed) return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@title("Scale Latents")
@tags("latents", "resize")
class ScaleLatentsInvocation(BaseInvocation): class ScaleLatentsInvocation(BaseInvocation):
"""Scales latents by a given factor.""" """Scales latents by a given factor."""
type: Literal["lscale"] = "lscale" type: Literal["lscale"] = "lscale"
# Inputs # Inputs
latents: Optional[LatentsField] = Field(description="The latents to scale") latents: LatentsField = InputField(
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents") description=FieldDescriptions.latents,
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode") input=Input.Connection,
antialias: bool = Field(
default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
) )
scale_factor: float = InputField(gt=0, description=FieldDescriptions.scale_factor)
class Config(InvocationConfig): mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
schema_extra = { antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
"ui": {"title": "Scale Latents", "tags": ["latents", "scale"]},
}
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name) latents = context.services.latents.get(self.latents.latents_name)
@ -658,22 +642,23 @@ class ScaleLatentsInvocation(BaseInvocation):
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed) return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@title("Image to Latents")
@tags("latents", "image", "vae")
class ImageToLatentsInvocation(BaseInvocation): class ImageToLatentsInvocation(BaseInvocation):
"""Encodes an image into latents.""" """Encodes an image into latents."""
type: Literal["i2l"] = "i2l" type: Literal["i2l"] = "i2l"
# Inputs # Inputs
image: Optional[ImageField] = Field(description="The image to encode") image: ImageField = InputField(
vae: VaeField = Field(default=None, description="Vae submodel") description="The image to encode",
tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)") )
fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision") vae: VaeField = InputField(
description=FieldDescriptions.vae,
# Schema customisation input=Input.Connection,
class Config(InvocationConfig): )
schema_extra = { tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
"ui": {"title": "Image To Latents", "tags": ["latents", "image"]}, fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
}
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:

View File

@ -2,134 +2,83 @@
from typing import Literal from typing import Literal
from pydantic import BaseModel, Field
import numpy as np import numpy as np
from .baseinvocation import ( from invokeai.app.invocations.primitives import IntegerOutput
BaseInvocation,
BaseInvocationOutput, from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, tags, title
InvocationContext,
InvocationConfig,
)
class MathInvocationConfig(BaseModel): @title("Add Integers")
"""Helper class to provide all math invocations with additional config""" @tags("math")
class AddInvocation(BaseInvocation):
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["math"],
}
}
class IntOutput(BaseInvocationOutput):
"""An integer output"""
# fmt: off
type: Literal["int_output"] = "int_output"
a: int = Field(default=None, description="The output integer")
# fmt: on
class FloatOutput(BaseInvocationOutput):
"""A float output"""
# fmt: off
type: Literal["float_output"] = "float_output"
param: float = Field(default=None, description="The output float")
# fmt: on
class AddInvocation(BaseInvocation, MathInvocationConfig):
"""Adds two numbers""" """Adds two numbers"""
# fmt: off
type: Literal["add"] = "add" type: Literal["add"] = "add"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
# fmt: on
class Config(InvocationConfig): # Inputs
schema_extra = { a: int = InputField(default=0, description=FieldDescriptions.num_1)
"ui": {"title": "Add", "tags": ["math", "add"]}, b: int = InputField(default=0, description=FieldDescriptions.num_2)
}
def invoke(self, context: InvocationContext) -> IntOutput: def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntOutput(a=self.a + self.b) return IntegerOutput(a=self.a + self.b)
class SubtractInvocation(BaseInvocation, MathInvocationConfig): @title("Subtract Integers")
@tags("math")
class SubtractInvocation(BaseInvocation):
"""Subtracts two numbers""" """Subtracts two numbers"""
# fmt: off
type: Literal["sub"] = "sub" type: Literal["sub"] = "sub"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
# fmt: on
class Config(InvocationConfig): # Inputs
schema_extra = { a: int = InputField(default=0, description=FieldDescriptions.num_1)
"ui": {"title": "Subtract", "tags": ["math", "subtract"]}, b: int = InputField(default=0, description=FieldDescriptions.num_2)
}
def invoke(self, context: InvocationContext) -> IntOutput: def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntOutput(a=self.a - self.b) return IntegerOutput(a=self.a - self.b)
class MultiplyInvocation(BaseInvocation, MathInvocationConfig): @title("Multiply Integers")
@tags("math")
class MultiplyInvocation(BaseInvocation):
"""Multiplies two numbers""" """Multiplies two numbers"""
# fmt: off
type: Literal["mul"] = "mul" type: Literal["mul"] = "mul"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
# fmt: on
class Config(InvocationConfig): # Inputs
schema_extra = { a: int = InputField(default=0, description=FieldDescriptions.num_1)
"ui": {"title": "Multiply", "tags": ["math", "multiply"]}, b: int = InputField(default=0, description=FieldDescriptions.num_2)
}
def invoke(self, context: InvocationContext) -> IntOutput: def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntOutput(a=self.a * self.b) return IntegerOutput(a=self.a * self.b)
class DivideInvocation(BaseInvocation, MathInvocationConfig): @title("Divide Integers")
@tags("math")
class DivideInvocation(BaseInvocation):
"""Divides two numbers""" """Divides two numbers"""
# fmt: off
type: Literal["div"] = "div" type: Literal["div"] = "div"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
# fmt: on
class Config(InvocationConfig): # Inputs
schema_extra = { a: int = InputField(default=0, description=FieldDescriptions.num_1)
"ui": {"title": "Divide", "tags": ["math", "divide"]}, b: int = InputField(default=0, description=FieldDescriptions.num_2)
}
def invoke(self, context: InvocationContext) -> IntOutput: def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntOutput(a=int(self.a / self.b)) return IntegerOutput(a=int(self.a / self.b))
@title("Random Integer")
@tags("math")
class RandomIntInvocation(BaseInvocation): class RandomIntInvocation(BaseInvocation):
"""Outputs a single random integer.""" """Outputs a single random integer."""
# fmt: off
type: Literal["rand_int"] = "rand_int" type: Literal["rand_int"] = "rand_int"
low: int = Field(default=0, description="The inclusive low value")
high: int = Field(
default=np.iinfo(np.int32).max, description="The exclusive high value"
)
# fmt: on
class Config(InvocationConfig): # Inputs
schema_extra = { low: int = InputField(default=0, description="The inclusive low value")
"ui": {"title": "Random Integer", "tags": ["math", "random", "integer"]}, high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
}
def invoke(self, context: InvocationContext) -> IntOutput: def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntOutput(a=np.random.randint(self.low, self.high)) return IntegerOutput(a=np.random.randint(self.low, self.high))

View File

@ -1,18 +1,22 @@
from typing import Literal, Optional, Union from typing import Literal, Optional
from pydantic import Field from pydantic import Field
from ...version import __version__
from invokeai.app.invocations.baseinvocation import ( from invokeai.app.invocations.baseinvocation import (
BaseInvocation, BaseInvocation,
BaseInvocationOutput, BaseInvocationOutput,
InvocationConfig, InputField,
InvocationContext, InvocationContext,
OutputField,
tags,
title,
) )
from invokeai.app.invocations.controlnet_image_processors import ControlField from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from ...version import __version__
class LoRAMetadataField(BaseModelExcludeNull): class LoRAMetadataField(BaseModelExcludeNull):
"""LoRA metadata for an image generated in InvokeAI.""" """LoRA metadata for an image generated in InvokeAI."""
@ -43,37 +47,37 @@ class CoreMetadata(BaseModelExcludeNull):
model: MainModelField = Field(description="The main model used for inference") model: MainModelField = Field(description="The main model used for inference")
controlnets: list[ControlField] = Field(description="The ControlNets used for inference") controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference") loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
vae: Union[VAEModelField, None] = Field( vae: Optional[VAEModelField] = Field(
default=None, default=None,
description="The VAE used for decoding, if the main model's default was not used", description="The VAE used for decoding, if the main model's default was not used",
) )
# Latents-to-Latents # Latents-to-Latents
strength: Union[float, None] = Field( strength: Optional[float] = Field(
default=None, default=None,
description="The strength used for latents-to-latents", description="The strength used for latents-to-latents",
) )
init_image: Union[str, None] = Field(default=None, description="The name of the initial image") init_image: Optional[str] = Field(default=None, description="The name of the initial image")
# SDXL # SDXL
positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter") positive_style_prompt: Optional[str] = Field(default=None, description="The positive style prompt parameter")
negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter") negative_style_prompt: Optional[str] = Field(default=None, description="The negative style prompt parameter")
# SDXL Refiner # SDXL Refiner
refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used") refiner_model: Optional[MainModelField] = Field(default=None, description="The SDXL Refiner model used")
refiner_cfg_scale: Union[float, None] = Field( refiner_cfg_scale: Optional[float] = Field(
default=None, default=None,
description="The classifier-free guidance scale parameter used for the refiner", description="The classifier-free guidance scale parameter used for the refiner",
) )
refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner") refiner_steps: Optional[int] = Field(default=None, description="The number of steps used for the refiner")
refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner") refiner_scheduler: Optional[str] = Field(default=None, description="The scheduler used for the refiner")
refiner_positive_aesthetic_store: Union[float, None] = Field( refiner_positive_aesthetic_store: Optional[float] = Field(
default=None, description="The aesthetic score used for the refiner" default=None, description="The aesthetic score used for the refiner"
) )
refiner_negative_aesthetic_store: Union[float, None] = Field( refiner_negative_aesthetic_store: Optional[float] = Field(
default=None, description="The aesthetic score used for the refiner" default=None, description="The aesthetic score used for the refiner"
) )
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising") refiner_start: Optional[float] = Field(default=None, description="The start value used for refiner denoising")
class ImageMetadata(BaseModelExcludeNull): class ImageMetadata(BaseModelExcludeNull):
@ -91,69 +95,86 @@ class MetadataAccumulatorOutput(BaseInvocationOutput):
type: Literal["metadata_accumulator_output"] = "metadata_accumulator_output" type: Literal["metadata_accumulator_output"] = "metadata_accumulator_output"
metadata: CoreMetadata = Field(description="The core metadata for the image") metadata: CoreMetadata = OutputField(description="The core metadata for the image")
@title("Metadata Accumulator")
@tags("metadata")
class MetadataAccumulatorInvocation(BaseInvocation): class MetadataAccumulatorInvocation(BaseInvocation):
"""Outputs a Core Metadata Object""" """Outputs a Core Metadata Object"""
type: Literal["metadata_accumulator"] = "metadata_accumulator" type: Literal["metadata_accumulator"] = "metadata_accumulator"
generation_mode: str = Field( generation_mode: str = InputField(
description="The generation mode that output this image", description="The generation mode that output this image",
) )
positive_prompt: str = Field(description="The positive prompt parameter") positive_prompt: str = InputField(description="The positive prompt parameter")
negative_prompt: str = Field(description="The negative prompt parameter") negative_prompt: str = InputField(description="The negative prompt parameter")
width: int = Field(description="The width parameter") width: int = InputField(description="The width parameter")
height: int = Field(description="The height parameter") height: int = InputField(description="The height parameter")
seed: int = Field(description="The seed used for noise generation") seed: int = InputField(description="The seed used for noise generation")
rand_device: str = Field(description="The device used for random number generation") rand_device: str = InputField(description="The device used for random number generation")
cfg_scale: float = Field(description="The classifier-free guidance scale parameter") cfg_scale: float = InputField(description="The classifier-free guidance scale parameter")
steps: int = Field(description="The number of steps used for inference") steps: int = InputField(description="The number of steps used for inference")
scheduler: str = Field(description="The scheduler used for inference") scheduler: str = InputField(description="The scheduler used for inference")
clip_skip: int = Field( clip_skip: int = InputField(
description="The number of skipped CLIP layers", description="The number of skipped CLIP layers",
) )
model: MainModelField = Field(description="The main model used for inference") model: MainModelField = InputField(description="The main model used for inference")
controlnets: list[ControlField] = Field(description="The ControlNets used for inference") controlnets: list[ControlField] = InputField(description="The ControlNets used for inference")
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference") loras: list[LoRAMetadataField] = InputField(description="The LoRAs used for inference")
strength: Union[float, None] = Field( strength: Optional[float] = InputField(
default=None, default=None,
description="The strength used for latents-to-latents", description="The strength used for latents-to-latents",
) )
init_image: Union[str, None] = Field(default=None, description="The name of the initial image") init_image: Optional[str] = InputField(
vae: Union[VAEModelField, None] = Field( default=None,
description="The name of the initial image",
)
vae: Optional[VAEModelField] = InputField(
default=None, default=None,
description="The VAE used for decoding, if the main model's default was not used", description="The VAE used for decoding, if the main model's default was not used",
) )
# SDXL # SDXL
positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter") positive_style_prompt: Optional[str] = InputField(
negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter") default=None,
description="The positive style prompt parameter",
)
negative_style_prompt: Optional[str] = InputField(
default=None,
description="The negative style prompt parameter",
)
# SDXL Refiner # SDXL Refiner
refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used") refiner_model: Optional[MainModelField] = InputField(
refiner_cfg_scale: Union[float, None] = Field( default=None,
description="The SDXL Refiner model used",
)
refiner_cfg_scale: Optional[float] = InputField(
default=None, default=None,
description="The classifier-free guidance scale parameter used for the refiner", description="The classifier-free guidance scale parameter used for the refiner",
) )
refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner") refiner_steps: Optional[int] = InputField(
refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner") default=None,
refiner_positive_aesthetic_score: Union[float, None] = Field( description="The number of steps used for the refiner",
default=None, description="The aesthetic score used for the refiner"
) )
refiner_negative_aesthetic_score: Union[float, None] = Field( refiner_scheduler: Optional[str] = InputField(
default=None, description="The aesthetic score used for the refiner" default=None,
description="The scheduler used for the refiner",
)
refiner_positive_aesthetic_store: Optional[float] = InputField(
default=None,
description="The aesthetic score used for the refiner",
)
refiner_negative_aesthetic_store: Optional[float] = InputField(
default=None,
description="The aesthetic score used for the refiner",
)
refiner_start: Optional[float] = InputField(
default=None,
description="The start value used for refiner denoising",
) )
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Metadata Accumulator",
"tags": ["image", "metadata", "generation"],
},
}
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput: def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
"""Collects and outputs a CoreMetadata object""" """Collects and outputs a CoreMetadata object"""

View File

@ -4,7 +4,18 @@ from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from ...backend.model_management import BaseModelType, ModelType, SubModelType from ...backend.model_management import BaseModelType, ModelType, SubModelType
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
InputField,
Input,
InvocationContext,
OutputField,
UIType,
tags,
title,
)
class ModelInfo(BaseModel): class ModelInfo(BaseModel):
@ -39,13 +50,11 @@ class VaeField(BaseModel):
class ModelLoaderOutput(BaseInvocationOutput): class ModelLoaderOutput(BaseInvocationOutput):
"""Model loader output""" """Model loader output"""
# fmt: off
type: Literal["model_loader_output"] = "model_loader_output" type: Literal["model_loader_output"] = "model_loader_output"
unet: UNetField = Field(default=None, description="UNet submodel") unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
vae: VaeField = Field(default=None, description="Vae submodel") vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
# fmt: on
class MainModelField(BaseModel): class MainModelField(BaseModel):
@ -63,24 +72,17 @@ class LoRAModelField(BaseModel):
base_model: BaseModelType = Field(description="Base model") base_model: BaseModelType = Field(description="Base model")
@title("Main Model Loader")
@tags("model")
class MainModelLoaderInvocation(BaseInvocation): class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels.""" """Loads a main model, outputting its submodels."""
type: Literal["main_model_loader"] = "main_model_loader" type: Literal["main_model_loader"] = "main_model_loader"
model: MainModelField = Field(description="The model to load") # Inputs
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
# TODO: precision? # TODO: precision?
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Model Loader",
"tags": ["model", "loader"],
"type_hints": {"model": "model"},
},
}
def invoke(self, context: InvocationContext) -> ModelLoaderOutput: def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = self.model.base_model base_model = self.model.base_model
model_name = self.model.model_name model_name = self.model.model_name
@ -155,22 +157,6 @@ class MainModelLoaderInvocation(BaseInvocation):
loras=[], loras=[],
skipped_layers=0, skipped_layers=0,
), ),
clip2=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
),
vae=VaeField( vae=VaeField(
vae=ModelInfo( vae=ModelInfo(
model_name=model_name, model_name=model_name,
@ -188,30 +174,27 @@ class LoraLoaderOutput(BaseInvocationOutput):
# fmt: off # fmt: off
type: Literal["lora_loader_output"] = "lora_loader_output" type: Literal["lora_loader_output"] = "lora_loader_output"
unet: Optional[UNetField] = Field(default=None, description="UNet submodel") unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels") clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
# fmt: on # fmt: on
@title("LoRA Loader")
@tags("lora", "model")
class LoraLoaderInvocation(BaseInvocation): class LoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder.""" """Apply selected lora to unet and text_encoder."""
type: Literal["lora_loader"] = "lora_loader" type: Literal["lora_loader"] = "lora_loader"
lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name") # Inputs
weight: float = Field(default=0.75, description="With what weight to apply lora") lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = Field(description="UNet model for applying lora") unet: Optional[UNetField] = InputField(
clip: Optional[ClipField] = Field(description="Clip model for applying lora") default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
)
class Config(InvocationConfig): clip: Optional[ClipField] = InputField(
schema_extra = { default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP"
"ui": { )
"title": "Lora Loader",
"tags": ["lora", "loader"],
"type_hints": {"lora": "lora_model"},
},
}
def invoke(self, context: InvocationContext) -> LoraLoaderOutput: def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
if self.lora is None: if self.lora is None:
@ -263,37 +246,35 @@ class LoraLoaderInvocation(BaseInvocation):
class SDXLLoraLoaderOutput(BaseInvocationOutput): class SDXLLoraLoaderOutput(BaseInvocationOutput):
"""Model loader output""" """SDXL LoRA Loader Output"""
# fmt: off # fmt: off
type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output" type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output"
unet: Optional[UNetField] = Field(default=None, description="UNet submodel") unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels") clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
clip2: Optional[ClipField] = Field(default=None, description="Tokenizer2 and text_encoder2 submodels") clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
# fmt: on # fmt: on
@title("SDXL LoRA Loader")
@tags("sdxl", "lora", "model")
class SDXLLoraLoaderInvocation(BaseInvocation): class SDXLLoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder.""" """Apply selected lora to unet and text_encoder."""
type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader" type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader"
lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name") lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
weight: float = Field(default=0.75, description="With what weight to apply lora") weight: float = Field(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = Field(
unet: Optional[UNetField] = Field(description="UNet model for applying lora") default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNET"
clip: Optional[ClipField] = Field(description="Clip model for applying lora") )
clip2: Optional[ClipField] = Field(description="Clip2 model for applying lora") clip: Optional[ClipField] = Field(
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1"
class Config(InvocationConfig): )
schema_extra = { clip2: Optional[ClipField] = Field(
"ui": { default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2"
"title": "SDXL Lora Loader", )
"tags": ["lora", "loader"],
"type_hints": {"lora": "lora_model"},
},
}
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
if self.lora is None: if self.lora is None:
@ -369,29 +350,23 @@ class VAEModelField(BaseModel):
class VaeLoaderOutput(BaseInvocationOutput): class VaeLoaderOutput(BaseInvocationOutput):
"""Model loader output""" """Model loader output"""
# fmt: off
type: Literal["vae_loader_output"] = "vae_loader_output" type: Literal["vae_loader_output"] = "vae_loader_output"
vae: VaeField = Field(default=None, description="Vae model") # Outputs
# fmt: on vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@title("VAE Loader")
@tags("vae", "model")
class VaeLoaderInvocation(BaseInvocation): class VaeLoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput""" """Loads a VAE model, outputting a VaeLoaderOutput"""
type: Literal["vae_loader"] = "vae_loader" type: Literal["vae_loader"] = "vae_loader"
vae_model: VAEModelField = Field(description="The VAE to load") # Inputs
vae_model: VAEModelField = InputField(
# Schema customisation description=FieldDescriptions.vae_model, input=Input.Direct, ui_type=UIType.VaeModel, title="VAE"
class Config(InvocationConfig): )
schema_extra = {
"ui": {
"title": "VAE Loader",
"tags": ["vae", "loader"],
"type_hints": {"vae_model": "vae_model"},
},
}
def invoke(self, context: InvocationContext) -> VaeLoaderOutput: def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
base_model = self.vae_model.base_model base_model = self.vae_model.base_model

View File

@ -1,19 +1,24 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
import math
from typing import Literal from typing import Literal
from pydantic import Field, validator
import torch import torch
from invokeai.app.invocations.latent import LatentsField from pydantic import validator
from invokeai.app.invocations.latent import LatentsField
from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.misc import SEED_MAX, get_random_seed
from ...backend.util.devices import choose_torch_device, torch_dtype from ...backend.util.devices import choose_torch_device, torch_dtype
from .baseinvocation import ( from .baseinvocation import (
BaseInvocation, BaseInvocation,
BaseInvocationOutput, BaseInvocationOutput,
InvocationConfig, FieldDescriptions,
InputField,
InvocationContext, InvocationContext,
OutputField,
UIType,
tags,
title,
) )
""" """
@ -61,14 +66,12 @@ Nodes
class NoiseOutput(BaseInvocationOutput): class NoiseOutput(BaseInvocationOutput):
"""Invocation noise output""" """Invocation noise output"""
# fmt: off
type: Literal["noise_output"] = "noise_output" type: Literal["noise_output"] = "noise_output"
# Inputs # Inputs
noise: LatentsField = Field(default=None, description="The output noise") noise: LatentsField = OutputField(default=None, description=FieldDescriptions.noise)
width: int = Field(description="The width of the noise in pixels") width: int = OutputField(description=FieldDescriptions.width)
height: int = Field(description="The height of the noise in pixels") height: int = OutputField(description=FieldDescriptions.height)
# fmt: on
def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int): def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
@ -79,44 +82,37 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
) )
@title("Noise")
@tags("latents", "noise")
class NoiseInvocation(BaseInvocation): class NoiseInvocation(BaseInvocation):
"""Generates latent noise.""" """Generates latent noise."""
type: Literal["noise"] = "noise" type: Literal["noise"] = "noise"
# Inputs # Inputs
seed: int = Field( seed: int = InputField(
ge=0, ge=0,
le=SEED_MAX, le=SEED_MAX,
description="The seed to use", description=FieldDescriptions.seed,
default_factory=get_random_seed, default_factory=get_random_seed,
) )
width: int = Field( width: int = InputField(
default=512, default=512,
multiple_of=8, multiple_of=8,
gt=0, gt=0,
description="The width of the resulting noise", description=FieldDescriptions.width,
) )
height: int = Field( height: int = InputField(
default=512, default=512,
multiple_of=8, multiple_of=8,
gt=0, gt=0,
description="The height of the resulting noise", description=FieldDescriptions.height,
) )
use_cpu: bool = Field( use_cpu: bool = InputField(
default=True, default=True,
description="Use CPU for noise generation (for reproducible results across platforms)", description="Use CPU for noise generation (for reproducible results across platforms)",
) )
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Noise",
"tags": ["latents", "noise"],
},
}
@validator("seed", pre=True) @validator("seed", pre=True)
def modulo_seed(cls, v): def modulo_seed(cls, v):
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range.""" """Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""

View File

@ -1,37 +1,43 @@
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779) # Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)
import inspect
import re
from contextlib import ExitStack from contextlib import ExitStack
from typing import List, Literal, Optional, Union from typing import List, Literal, Optional, Union
import re
import inspect
from pydantic import BaseModel, Field, validator
import torch
import numpy as np import numpy as np
import torch
from diffusers import ControlNetModel, DPMSolverMultistepScheduler from diffusers import ControlNetModel, DPMSolverMultistepScheduler
from diffusers.image_processor import VaeImageProcessor from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import BaseModel, Field, validator
from ..models.image import ImageCategory, ImageField, ResourceOrigin from tqdm import tqdm
from ...backend.model_management import ONNXModelPatcher
from ...backend.util import choose_torch_device
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .compel import ConditioningField
from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField
from invokeai.app.invocations.metadata import CoreMetadata from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.backend import BaseModelType, ModelType, SubModelType from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend import BaseModelType, ModelType, SubModelType
from ...backend.model_management import ONNXModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util import choose_torch_device
from tqdm import tqdm from ..models.image import ImageCategory, ResourceOrigin
from .model import ClipField from .baseinvocation import (
from .latent import LatentsField, LatentsOutput, build_latents_output, get_scheduler, SAMPLER_NAME_VALUES BaseInvocation,
from .compel import CompelOutput BaseInvocationOutput,
FieldDescriptions,
InputField,
Input,
InvocationContext,
OutputField,
UIComponent,
UIType,
tags,
title,
)
from .controlnet_image_processors import ControlField
from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler
from .model import ClipField, ModelInfo, UNetField, VaeField
ORT_TO_NP_TYPE = { ORT_TO_NP_TYPE = {
"tensor(bool)": np.bool_, "tensor(bool)": np.bool_,
@ -51,13 +57,15 @@ ORT_TO_NP_TYPE = {
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))] PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
@title("ONNX Prompt (Raw)")
@tags("onnx", "prompt")
class ONNXPromptInvocation(BaseInvocation): class ONNXPromptInvocation(BaseInvocation):
type: Literal["prompt_onnx"] = "prompt_onnx" type: Literal["prompt_onnx"] = "prompt_onnx"
prompt: str = Field(default="", description="Prompt") prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
clip: ClipField = Field(None, description="Clip to use") clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
def invoke(self, context: InvocationContext) -> CompelOutput: def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.services.model_manager.get_model( tokenizer_info = context.services.model_manager.get_model(
**self.clip.tokenizer.dict(), **self.clip.tokenizer.dict(),
) )
@ -126,7 +134,7 @@ class ONNXPromptInvocation(BaseInvocation):
# TODO: hacky but works ;D maybe rename latents somehow? # TODO: hacky but works ;D maybe rename latents somehow?
context.services.latents.save(conditioning_name, (prompt_embeds, None)) context.services.latents.save(conditioning_name, (prompt_embeds, None))
return CompelOutput( return ConditioningOutput(
conditioning=ConditioningField( conditioning=ConditioningField(
conditioning_name=conditioning_name, conditioning_name=conditioning_name,
), ),
@ -134,25 +142,48 @@ class ONNXPromptInvocation(BaseInvocation):
# Text to image # Text to image
@title("ONNX Text to Latents")
@tags("latents", "inference", "txt2img", "onnx")
class ONNXTextToLatentsInvocation(BaseInvocation): class ONNXTextToLatentsInvocation(BaseInvocation):
"""Generates latents from conditionings.""" """Generates latents from conditionings."""
type: Literal["t2l_onnx"] = "t2l_onnx" type: Literal["t2l_onnx"] = "t2l_onnx"
# Inputs # Inputs
# fmt: off positive_conditioning: ConditioningField = InputField(
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation") description=FieldDescriptions.positive_cond,
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation") input=Input.Connection,
noise: Optional[LatentsField] = Field(description="The noise to use") )
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image") negative_conditioning: ConditioningField = InputField(
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) description=FieldDescriptions.negative_cond,
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" ) input=Input.Connection,
precision: PRECISION_VALUES = Field(default = "tensor(float16)", description="The precision to use when generating latents") )
unet: UNetField = Field(default=None, description="UNet submodel") noise: LatentsField = InputField(
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use") description=FieldDescriptions.noise,
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) input=Input.Connection,
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'") )
# fmt: on steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
cfg_scale: Union[float, List[float]] = InputField(
default=7.5,
ge=1,
description=FieldDescriptions.cfg_scale,
ui_type=UIType.Float,
)
scheduler: SAMPLER_NAME_VALUES = InputField(
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct
)
precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision)
unet: UNetField = InputField(
description=FieldDescriptions.unet,
input=Input.Connection,
)
control: Optional[Union[ControlField, list[ControlField]]] = InputField(
default=None,
description=FieldDescriptions.control,
ui_type=UIType.Control,
)
# seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", )
# seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'")
@validator("cfg_scale") @validator("cfg_scale")
def ge_one(cls, v): def ge_one(cls, v):
@ -166,20 +197,6 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
raise ValueError("cfg_scale must be greater than 1") raise ValueError("cfg_scale must be greater than 1")
return v return v
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents"],
"type_hints": {
"model": "model",
"control": "control",
# "cfg_scale": "float",
"cfg_scale": "number",
},
},
}
# based on # based on
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375 # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
@ -300,26 +317,28 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
# Latent to image # Latent to image
@title("ONNX Latents to Image")
@tags("latents", "image", "vae", "onnx")
class ONNXLatentsToImageInvocation(BaseInvocation): class ONNXLatentsToImageInvocation(BaseInvocation):
"""Generates an image from latents.""" """Generates an image from latents."""
type: Literal["l2i_onnx"] = "l2i_onnx" type: Literal["l2i_onnx"] = "l2i_onnx"
# Inputs # Inputs
latents: Optional[LatentsField] = Field(description="The latents to generate an image from") latents: LatentsField = InputField(
vae: VaeField = Field(default=None, description="Vae submodel") description=FieldDescriptions.denoised_latents,
metadata: Optional[CoreMetadata] = Field( input=Input.Connection,
default=None, description="Optional core metadata to be written to the image"
) )
# tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)") vae: VaeField = InputField(
description=FieldDescriptions.vae,
# Schema customisation input=Input.Connection,
class Config(InvocationConfig): )
schema_extra = { metadata: Optional[CoreMetadata] = InputField(
"ui": { default=None,
"tags": ["latents", "image"], description=FieldDescriptions.core_metadata,
}, ui_hidden=True,
} )
# tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.services.latents.get(self.latents.latents_name) latents = context.services.latents.get(self.latents.latents_name)
@ -373,89 +392,13 @@ class ONNXModelLoaderOutput(BaseInvocationOutput):
# fmt: off # fmt: off
type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx" type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx"
unet: UNetField = Field(default=None, description="UNet submodel") unet: UNetField = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
vae_decoder: VaeField = Field(default=None, description="Vae submodel") vae_decoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Decoder")
vae_encoder: VaeField = Field(default=None, description="Vae submodel") vae_encoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Encoder")
# fmt: on # fmt: on
class ONNXSD1ModelLoaderInvocation(BaseInvocation):
    """Loads the ONNX submodels (UNet, scheduler, CLIP tokenizer/encoder,
    VAE encoder/decoder) of a Stable Diffusion 1.x model and returns them
    bundled in an ONNXModelLoaderOutput.
    """

    type: Literal["sd1_model_loader_onnx"] = "sd1_model_loader_onnx"

    # NOTE(review): this field is currently ignored — invoke() hardcodes
    # "stable-diffusion-v1-5" below. Presumably a stub; confirm before
    # relying on it.
    model_name: str = Field(default="", description="Model to load")
    # TODO: precision?

    # Schema customisation — UI hints only; no runtime behavior.
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"tags": ["model", "loader"], "type_hints": {"model_name": "model"}},  # TODO: rename to model_name?
        }

    def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
        model_name = "stable-diffusion-v1-5"
        base_model = BaseModelType.StableDiffusion1

        # TODO: not found exceptions
        if not context.services.model_manager.model_exists(
            model_name=model_name,
            base_model=BaseModelType.StableDiffusion1,
            model_type=ModelType.ONNX,
        ):
            # BUG FIX: error message typo corrected ("Unkown" -> "Unknown").
            raise Exception(f"Unknown model name: {model_name}!")

        # Each submodel is referenced by a ModelInfo record; the model
        # manager resolves these lazily when the graph executes.
        return ONNXModelLoaderOutput(
            unet=UNetField(
                unet=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=ModelType.ONNX,
                    submodel=SubModelType.UNet,
                ),
                scheduler=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=ModelType.ONNX,
                    submodel=SubModelType.Scheduler,
                ),
                loras=[],
            ),
            clip=ClipField(
                tokenizer=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=ModelType.ONNX,
                    submodel=SubModelType.Tokenizer,
                ),
                text_encoder=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=ModelType.ONNX,
                    submodel=SubModelType.TextEncoder,
                ),
                loras=[],
            ),
            vae_decoder=VaeField(
                vae=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=ModelType.ONNX,
                    submodel=SubModelType.VaeDecoder,
                ),
            ),
            vae_encoder=VaeField(
                vae=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=ModelType.ONNX,
                    submodel=SubModelType.VaeEncoder,
                ),
            ),
        )
class OnnxModelField(BaseModel): class OnnxModelField(BaseModel):
"""Onnx model field""" """Onnx model field"""
@ -464,22 +407,17 @@ class OnnxModelField(BaseModel):
model_type: ModelType = Field(description="Model Type") model_type: ModelType = Field(description="Model Type")
@title("ONNX Model Loader")
@tags("onnx", "model")
class OnnxModelLoaderInvocation(BaseInvocation): class OnnxModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels.""" """Loads a main model, outputting its submodels."""
type: Literal["onnx_model_loader"] = "onnx_model_loader" type: Literal["onnx_model_loader"] = "onnx_model_loader"
model: OnnxModelField = Field(description="The model to load") # Inputs
model: OnnxModelField = InputField(
# Schema customisation description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel
class Config(InvocationConfig): )
schema_extra = {
"ui": {
"title": "Onnx Model Loader",
"tags": ["model", "loader"],
"type_hints": {"model": "model"},
},
}
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput: def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
base_model = self.model.base_model base_model = self.model.base_model

View File

@ -1,73 +1,64 @@
import io import io
from typing import Literal, Optional, Any from typing import Literal, Optional
# from PIL.Image import Image
import PIL.Image
from matplotlib.ticker import MaxNLocator
from matplotlib.figure import Figure
from pydantic import BaseModel, Field
import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
from easing_functions import ( from easing_functions import (
LinearInOut,
QuadEaseInOut,
QuadEaseIn,
QuadEaseOut,
CubicEaseInOut,
CubicEaseIn,
CubicEaseOut,
QuarticEaseInOut,
QuarticEaseIn,
QuarticEaseOut,
QuinticEaseInOut,
QuinticEaseIn,
QuinticEaseOut,
SineEaseInOut,
SineEaseIn,
SineEaseOut,
CircularEaseIn,
CircularEaseInOut,
CircularEaseOut,
ExponentialEaseInOut,
ExponentialEaseIn,
ExponentialEaseOut,
ElasticEaseIn,
ElasticEaseInOut,
ElasticEaseOut,
BackEaseIn, BackEaseIn,
BackEaseInOut, BackEaseInOut,
BackEaseOut, BackEaseOut,
BounceEaseIn, BounceEaseIn,
BounceEaseInOut, BounceEaseInOut,
BounceEaseOut, BounceEaseOut,
CircularEaseIn,
CircularEaseInOut,
CircularEaseOut,
CubicEaseIn,
CubicEaseInOut,
CubicEaseOut,
ElasticEaseIn,
ElasticEaseInOut,
ElasticEaseOut,
ExponentialEaseIn,
ExponentialEaseInOut,
ExponentialEaseOut,
LinearInOut,
QuadEaseIn,
QuadEaseInOut,
QuadEaseOut,
QuarticEaseIn,
QuarticEaseInOut,
QuarticEaseOut,
QuinticEaseIn,
QuinticEaseInOut,
QuinticEaseOut,
SineEaseIn,
SineEaseInOut,
SineEaseOut,
) )
from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator
from pydantic import BaseModel, Field
from invokeai.app.invocations.primitives import FloatCollectionOutput
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationContext,
InvocationConfig,
)
from ...backend.util.logging import InvokeAILogger from ...backend.util.logging import InvokeAILogger
from .collections import FloatCollectionOutput from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
@title("Float Range")
@tags("math", "range")
class FloatLinearRangeInvocation(BaseInvocation): class FloatLinearRangeInvocation(BaseInvocation):
"""Creates a range""" """Creates a range"""
type: Literal["float_range"] = "float_range" type: Literal["float_range"] = "float_range"
# Inputs # Inputs
start: float = Field(default=5, description="The first value of the range") start: float = InputField(default=5, description="The first value of the range")
stop: float = Field(default=10, description="The last value of the range") stop: float = InputField(default=10, description="The last value of the range")
steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)") steps: int = InputField(default=30, description="number of values to interpolate over (including start and stop)")
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Linear Range (Float)", "tags": ["math", "float", "linear", "range"]},
}
def invoke(self, context: InvocationContext) -> FloatCollectionOutput: def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
param_list = list(np.linspace(self.start, self.stop, self.steps)) param_list = list(np.linspace(self.start, self.stop, self.steps))
@ -108,37 +99,32 @@ EASING_FUNCTIONS_MAP = {
"BounceInOut": BounceEaseInOut, "BounceInOut": BounceEaseInOut,
} }
EASING_FUNCTION_KEYS: Any = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))] EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
# actually I think for now could just use CollectionOutput (which is list[Any] # actually I think for now could just use CollectionOutput (which is list[Any]
@title("Step Param Easing")
@tags("step", "easing")
class StepParamEasingInvocation(BaseInvocation): class StepParamEasingInvocation(BaseInvocation):
"""Experimental per-step parameter easing for denoising steps""" """Experimental per-step parameter easing for denoising steps"""
type: Literal["step_param_easing"] = "step_param_easing" type: Literal["step_param_easing"] = "step_param_easing"
# Inputs # Inputs
# fmt: off easing: EASING_FUNCTION_KEYS = InputField(default="Linear", description="The easing function to use")
easing: EASING_FUNCTION_KEYS = Field(default="Linear", description="The easing function to use") num_steps: int = InputField(default=20, description="number of denoising steps")
num_steps: int = Field(default=20, description="number of denoising steps") start_value: float = InputField(default=0.0, description="easing starting value")
start_value: float = Field(default=0.0, description="easing starting value") end_value: float = InputField(default=1.0, description="easing ending value")
end_value: float = Field(default=1.0, description="easing ending value") start_step_percent: float = InputField(default=0.0, description="fraction of steps at which to start easing")
start_step_percent: float = Field(default=0.0, description="fraction of steps at which to start easing") end_step_percent: float = InputField(default=1.0, description="fraction of steps after which to end easing")
end_step_percent: float = Field(default=1.0, description="fraction of steps after which to end easing")
# if None, then start_value is used prior to easing start # if None, then start_value is used prior to easing start
pre_start_value: Optional[float] = Field(default=None, description="value before easing start") pre_start_value: Optional[float] = InputField(default=None, description="value before easing start")
# if None, then end value is used prior to easing end # if None, then end value is used prior to easing end
post_end_value: Optional[float] = Field(default=None, description="value after easing end") post_end_value: Optional[float] = InputField(default=None, description="value after easing end")
mirror: bool = Field(default=False, description="include mirror of easing function") mirror: bool = InputField(default=False, description="include mirror of easing function")
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely # FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
# alt_mirror: bool = Field(default=False, description="alternative mirroring by dual easing") # alt_mirror: bool = InputField(default=False, description="alternative mirroring by dual easing")
show_easing_plot: bool = Field(default=False, description="show easing plot") show_easing_plot: bool = InputField(default=False, description="show easing plot")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Param Easing By Step", "tags": ["param", "step", "easing"]},
}
def invoke(self, context: InvocationContext) -> FloatCollectionOutput: def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
log_diagnostics = False log_diagnostics = False

View File

@ -1,83 +0,0 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal
from pydantic import Field
from invokeai.app.invocations.prompt import PromptOutput
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .math import FloatOutput, IntOutput
# Pass-through parameter nodes - used by subgraphs
# Pass-through parameter node: exposes a single integer as a graph input.
class ParamIntInvocation(BaseInvocation):
    """An integer parameter"""
    # fmt: off
    type: Literal["param_int"] = "param_int"  # node discriminator used by the graph schema
    a: int = Field(default=0, description="The integer value")  # value to pass through
    # fmt: on

    # UI hints only; no runtime behavior.
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"tags": ["param", "integer"], "title": "Integer Parameter"},
        }

    def invoke(self, context: InvocationContext) -> IntOutput:
        # Forward the configured value unchanged.
        return IntOutput(a=self.a)


# Pass-through parameter node: exposes a single float as a graph input.
class ParamFloatInvocation(BaseInvocation):
    """A float parameter"""
    # fmt: off
    type: Literal["param_float"] = "param_float"  # node discriminator
    param: float = Field(default=0.0, description="The float value")  # value to pass through
    # fmt: on

    # UI hints only; no runtime behavior.
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"tags": ["param", "float"], "title": "Float Parameter"},
        }

    def invoke(self, context: InvocationContext) -> FloatOutput:
        # Forward the configured value unchanged.
        return FloatOutput(param=self.param)
class StringOutput(BaseInvocationOutput):
    """A string output"""
    type: Literal["string_output"] = "string_output"
    # NOTE(review): default=None on a non-Optional str field — presumably
    # relies on pydantic not validating defaults; confirm intent.
    text: str = Field(default=None, description="The output string")


# Pass-through parameter node: exposes a string as a graph input.
class ParamStringInvocation(BaseInvocation):
    """A string parameter"""
    type: Literal["param_string"] = "param_string"  # node discriminator
    text: str = Field(default="", description="The string value")  # value to pass through

    # UI hints only; no runtime behavior.
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"tags": ["param", "string"], "title": "String Parameter"},
        }

    def invoke(self, context: InvocationContext) -> StringOutput:
        # Forward the configured value unchanged.
        return StringOutput(text=self.text)


# Pass-through parameter node: exposes a prompt string as a graph input.
class ParamPromptInvocation(BaseInvocation):
    """A prompt input parameter"""
    type: Literal["param_prompt"] = "param_prompt"  # node discriminator
    prompt: str = Field(default="", description="The prompt value")  # value to pass through

    # UI hints only; no runtime behavior.
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"tags": ["param", "prompt"], "title": "Prompt"},
        }

    def invoke(self, context: InvocationContext) -> PromptOutput:
        # Forward the configured value unchanged.
        return PromptOutput(prompt=self.prompt)

View File

@ -0,0 +1,494 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal, Optional, Tuple, Union
from anyio import Condition
from pydantic import BaseModel, Field
import torch
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
OutputField,
UIComponent,
UIType,
tags,
title,
)
"""
Primitives: Boolean, Integer, Float, String, Image, Latents, Conditioning, Color
- primitive nodes
- primitive outputs
- primitive collection outputs
"""
# region Boolean
class BooleanOutput(BaseInvocationOutput):
    """Base class for nodes that output a single boolean"""

    type: Literal["boolean_output"] = "boolean_output"  # output discriminator
    a: bool = OutputField(description="The output boolean")


class BooleanCollectionOutput(BaseInvocationOutput):
    """Base class for nodes that output a collection of booleans"""

    type: Literal["boolean_collection_output"] = "boolean_collection_output"

    # Outputs
    collection: list[bool] = OutputField(
        default_factory=list, description="The output boolean collection", ui_type=UIType.BooleanCollection
    )


@title("Boolean Primitive")
@tags("primitives", "boolean")
class BooleanInvocation(BaseInvocation):
    """A boolean primitive value"""

    type: Literal["boolean"] = "boolean"

    # Inputs
    a: bool = InputField(default=False, description="The boolean value")

    def invoke(self, context: InvocationContext) -> BooleanOutput:
        # Pass-through: forward the configured value unchanged.
        return BooleanOutput(a=self.a)
@title("Boolean Primitive Collection")
@tags("primitives", "boolean", "collection")
class BooleanCollectionInvocation(BaseInvocation):
    """A collection of boolean primitive values"""

    type: Literal["boolean_collection"] = "boolean_collection"

    # Inputs
    # BUG FIX: the default was `False`, which is not a valid value for a
    # `list[bool]` field (pydantic does not validate defaults, so invoke()
    # would have forwarded a bare bool). Default to an empty list, matching
    # the sibling *CollectionOutput fields.
    # NOTE(review): assumes InputField forwards `default_factory` to
    # pydantic.Field as OutputField evidently does — confirm.
    collection: list[bool] = InputField(
        default_factory=list, description="The collection of boolean values", ui_type=UIType.BooleanCollection
    )

    def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
        # Pass-through: forward the collection unchanged.
        return BooleanCollectionOutput(collection=self.collection)
# endregion
# region Integer
class IntegerOutput(BaseInvocationOutput):
    """Base class for nodes that output a single integer"""

    type: Literal["integer_output"] = "integer_output"  # output discriminator
    a: int = OutputField(description="The output integer")


class IntegerCollectionOutput(BaseInvocationOutput):
    """Base class for nodes that output a collection of integers"""

    type: Literal["integer_collection_output"] = "integer_collection_output"

    # Outputs
    collection: list[int] = OutputField(
        default_factory=list, description="The int collection", ui_type=UIType.IntegerCollection
    )


@title("Integer Primitive")
@tags("primitives", "integer")
class IntegerInvocation(BaseInvocation):
    """An integer primitive value"""

    type: Literal["integer"] = "integer"

    # Inputs
    a: int = InputField(default=0, description="The integer value")

    def invoke(self, context: InvocationContext) -> IntegerOutput:
        # Pass-through: forward the configured value unchanged.
        return IntegerOutput(a=self.a)
@title("Integer Primitive Collection")
@tags("primitives", "integer", "collection")
class IntegerCollectionInvocation(BaseInvocation):
    """A collection of integer primitive values"""

    type: Literal["integer_collection"] = "integer_collection"

    # Inputs
    # BUG FIX: the default was `0`, which is not a valid value for a
    # `list[int]` field. Default to an empty list, matching the sibling
    # *CollectionOutput fields.
    # NOTE(review): assumes InputField forwards `default_factory` to
    # pydantic.Field as OutputField evidently does — confirm.
    collection: list[int] = InputField(
        default_factory=list, description="The collection of integer values", ui_type=UIType.IntegerCollection
    )

    def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
        # Pass-through: forward the collection unchanged.
        return IntegerCollectionOutput(collection=self.collection)
# endregion
# region Float
class FloatOutput(BaseInvocationOutput):
    """Base class for nodes that output a single float"""

    type: Literal["float_output"] = "float_output"  # output discriminator
    a: float = OutputField(description="The output float")


class FloatCollectionOutput(BaseInvocationOutput):
    """Base class for nodes that output a collection of floats"""

    type: Literal["float_collection_output"] = "float_collection_output"

    # Outputs
    collection: list[float] = OutputField(
        default_factory=list, description="The float collection", ui_type=UIType.FloatCollection
    )


@title("Float Primitive")
@tags("primitives", "float")
class FloatInvocation(BaseInvocation):
    """A float primitive value"""

    type: Literal["float"] = "float"

    # Inputs
    # NOTE(review): input field is named `param` while the output field is
    # `a` — inconsistent with the other primitives, but renaming would break
    # existing graphs.
    param: float = InputField(default=0.0, description="The float value")

    def invoke(self, context: InvocationContext) -> FloatOutput:
        # Pass-through: forward the configured value unchanged.
        return FloatOutput(a=self.param)
@title("Float Primitive Collection")
@tags("primitives", "float", "collection")
class FloatCollectionInvocation(BaseInvocation):
    """A collection of float primitive values"""

    type: Literal["float_collection"] = "float_collection"

    # Inputs
    # BUG FIX: the default was `0`, which is not a valid value for a
    # `list[float]` field. Default to an empty list, matching the sibling
    # *CollectionOutput fields.
    # NOTE(review): assumes InputField forwards `default_factory` to
    # pydantic.Field as OutputField evidently does — confirm.
    collection: list[float] = InputField(
        default_factory=list, description="The collection of float values", ui_type=UIType.FloatCollection
    )

    def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
        # Pass-through: forward the collection unchanged.
        return FloatCollectionOutput(collection=self.collection)
# endregion
# region String
class StringOutput(BaseInvocationOutput):
    """Base class for nodes that output a single string"""

    type: Literal["string_output"] = "string_output"  # output discriminator
    text: str = OutputField(description="The output string")


class StringCollectionOutput(BaseInvocationOutput):
    """Base class for nodes that output a collection of strings"""

    type: Literal["string_collection_output"] = "string_collection_output"

    # Outputs
    collection: list[str] = OutputField(
        default_factory=list, description="The output strings", ui_type=UIType.StringCollection
    )


@title("String Primitive")
@tags("primitives", "string")
class StringInvocation(BaseInvocation):
    """A string primitive value"""

    type: Literal["string"] = "string"

    # Inputs
    text: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea)

    def invoke(self, context: InvocationContext) -> StringOutput:
        # Pass-through: forward the configured value unchanged.
        return StringOutput(text=self.text)
@title("String Primitive Collection")
@tags("primitives", "string", "collection")
class StringCollectionInvocation(BaseInvocation):
    """A collection of string primitive values"""

    type: Literal["string_collection"] = "string_collection"

    # Inputs
    # BUG FIX: the default was `0`, which is not a valid value for a
    # `list[str]` field. Default to an empty list, matching the sibling
    # *CollectionOutput fields.
    # NOTE(review): assumes InputField forwards `default_factory` to
    # pydantic.Field as OutputField evidently does — confirm.
    collection: list[str] = InputField(
        default_factory=list, description="The collection of string values", ui_type=UIType.StringCollection
    )

    def invoke(self, context: InvocationContext) -> StringCollectionOutput:
        # Pass-through: forward the collection unchanged.
        return StringCollectionOutput(collection=self.collection)
# endregion
# region Image
class ImageField(BaseModel):
    """An image primitive field"""

    # Images are addressed by name; the image service resolves the bytes.
    image_name: str = Field(description="The name of the image")


class ImageOutput(BaseInvocationOutput):
    """Base class for nodes that output a single image"""

    type: Literal["image_output"] = "image_output"

    image: ImageField = OutputField(description="The output image")
    width: int = OutputField(description="The width of the image in pixels")
    height: int = OutputField(description="The height of the image in pixels")


class ImageCollectionOutput(BaseInvocationOutput):
    """Base class for nodes that output a collection of images"""

    type: Literal["image_collection_output"] = "image_collection_output"

    # Outputs
    collection: list[ImageField] = OutputField(
        default_factory=list, description="The output images", ui_type=UIType.ImageCollection
    )


@title("Image Primitive")
@tags("primitives", "image")
class ImageInvocation(BaseInvocation):
    """An image primitive value"""

    # Metadata
    type: Literal["image"] = "image"

    # Inputs
    image: ImageField = InputField(description="The image to load")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        # Load the PIL image only to report its dimensions; the output
        # still references the image by name.
        image = context.services.images.get_pil_image(self.image.image_name)

        return ImageOutput(
            image=ImageField(image_name=self.image.image_name),
            width=image.width,
            height=image.height,
        )
@title("Image Primitive Collection")
@tags("primitives", "image", "collection")
class ImageCollectionInvocation(BaseInvocation):
    """A collection of image primitive values"""

    type: Literal["image_collection"] = "image_collection"

    # Inputs
    # BUG FIX: the default was `0`, which is not a valid value for a
    # `list[ImageField]` field. Default to an empty list, matching the
    # sibling *CollectionOutput fields.
    # NOTE(review): assumes InputField forwards `default_factory` to
    # pydantic.Field as OutputField evidently does — confirm.
    collection: list[ImageField] = InputField(
        default_factory=list, description="The collection of image values", ui_type=UIType.ImageCollection
    )

    def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
        # Pass-through: forward the collection unchanged.
        return ImageCollectionOutput(collection=self.collection)
# endregion
# region Latents
class LatentsField(BaseModel):
    """A latents tensor primitive field"""

    # Latents tensors are addressed by name; the latents service resolves them.
    latents_name: str = Field(description="The name of the latents")
    seed: Optional[int] = Field(default=None, description="Seed used to generate this latents")


class LatentsOutput(BaseInvocationOutput):
    """Base class for nodes that output a single latents tensor"""

    type: Literal["latents_output"] = "latents_output"

    latents: LatentsField = OutputField(
        description=FieldDescriptions.latents,
    )
    width: int = OutputField(description=FieldDescriptions.width)
    height: int = OutputField(description=FieldDescriptions.height)


class LatentsCollectionOutput(BaseInvocationOutput):
    """Base class for nodes that output a collection of latents tensors"""

    type: Literal["latents_collection_output"] = "latents_collection_output"

    collection: list[LatentsField] = OutputField(
        default_factory=list,
        description=FieldDescriptions.latents,
        ui_type=UIType.LatentsCollection,
    )


@title("Latents Primitive")
@tags("primitives", "latents")
class LatentsInvocation(BaseInvocation):
    """A latents tensor primitive value"""

    type: Literal["latents"] = "latents"

    # Inputs
    latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        # Fetch the tensor so build_latents_output can derive pixel dimensions.
        latents = context.services.latents.get(self.latents.latents_name)

        return build_latents_output(self.latents.latents_name, latents)
@title("Latents Primitive Collection")
@tags("primitives", "latents", "collection")
class LatentsCollectionInvocation(BaseInvocation):
    """A collection of latents tensor primitive values"""

    type: Literal["latents_collection"] = "latents_collection"

    # Inputs
    # BUG FIX: the default was `0`, which is not a valid value for a
    # `list[LatentsField]` field. Default to an empty list, matching the
    # sibling *CollectionOutput fields.
    # NOTE(review): assumes InputField forwards `default_factory` to
    # pydantic.Field as OutputField evidently does — confirm.
    collection: list[LatentsField] = InputField(
        default_factory=list, description="The collection of latents tensors", ui_type=UIType.LatentsCollection
    )

    def invoke(self, context: InvocationContext) -> LatentsCollectionOutput:
        # Pass-through: forward the collection unchanged.
        return LatentsCollectionOutput(collection=self.collection)
def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None) -> LatentsOutput:
    """Build a LatentsOutput for the named tensor.

    Pixel dimensions are derived from the tensor shape: dims 2/3 are the
    latent height/width, scaled by the 8x VAE downsampling factor.
    Assumes `latents` is a 4-D (batch, channels, height, width) tensor —
    indexing dims 2 and 3 requires at least rank 4.
    """
    return LatentsOutput(
        latents=LatentsField(latents_name=latents_name, seed=seed),
        width=latents.size()[3] * 8,
        height=latents.size()[2] * 8,
    )
# endregion
# region Color
class ColorField(BaseModel):
    """A color primitive field"""

    # RGBA components, each constrained to the 0-255 byte range.
    r: int = Field(ge=0, le=255, description="The red component")
    g: int = Field(ge=0, le=255, description="The green component")
    b: int = Field(ge=0, le=255, description="The blue component")
    a: int = Field(ge=0, le=255, description="The alpha component")

    def tuple(self) -> Tuple[int, int, int, int]:
        # NOTE: method name shadows the builtin `tuple` on instances.
        # Returns (r, g, b, a), e.g. for PIL color arguments.
        return (self.r, self.g, self.b, self.a)


class ColorOutput(BaseInvocationOutput):
    """Base class for nodes that output a single color"""

    type: Literal["color_output"] = "color_output"
    color: ColorField = OutputField(description="The output color")


class ColorCollectionOutput(BaseInvocationOutput):
    """Base class for nodes that output a collection of colors"""

    type: Literal["color_collection_output"] = "color_collection_output"

    # Outputs
    collection: list[ColorField] = OutputField(
        default_factory=list, description="The output colors", ui_type=UIType.ColorCollection
    )


@title("Color Primitive")
@tags("primitives", "color")
class ColorInvocation(BaseInvocation):
    """A color primitive value"""

    type: Literal["color"] = "color"

    # Inputs — defaults to opaque black.
    color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color value")

    def invoke(self, context: InvocationContext) -> ColorOutput:
        # Pass-through: forward the configured value unchanged.
        return ColorOutput(color=self.color)
# endregion
# region Conditioning
class ConditioningField(BaseModel):
    """A conditioning tensor primitive value"""

    # Conditioning tensors are addressed by name and resolved by a service.
    conditioning_name: str = Field(description="The name of conditioning tensor")


class ConditioningOutput(BaseInvocationOutput):
    """Base class for nodes that output a single conditioning tensor"""

    type: Literal["conditioning_output"] = "conditioning_output"

    conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)


class ConditioningCollectionOutput(BaseInvocationOutput):
    """Base class for nodes that output a collection of conditioning tensors"""

    type: Literal["conditioning_collection_output"] = "conditioning_collection_output"

    # Outputs
    collection: list[ConditioningField] = OutputField(
        default_factory=list,
        description="The output conditioning tensors",
        ui_type=UIType.ConditioningCollection,
    )


@title("Conditioning Primitive")
@tags("primitives", "conditioning")
class ConditioningInvocation(BaseInvocation):
    """A conditioning tensor primitive value"""

    type: Literal["conditioning"] = "conditioning"

    conditioning: ConditioningField = InputField(description=FieldDescriptions.cond, input=Input.Connection)

    def invoke(self, context: InvocationContext) -> ConditioningOutput:
        # Pass-through: forward the configured value unchanged.
        return ConditioningOutput(conditioning=self.conditioning)
@title("Conditioning Primitive Collection")
@tags("primitives", "conditioning", "collection")
class ConditioningCollectionInvocation(BaseInvocation):
    """A collection of conditioning tensor primitive values"""

    type: Literal["conditioning_collection"] = "conditioning_collection"

    # Inputs
    # BUG FIX: the default was `0`, which is not a valid value for a
    # `list[ConditioningField]` field. Default to an empty list, matching
    # the sibling *CollectionOutput fields.
    # NOTE(review): assumes InputField forwards `default_factory` to
    # pydantic.Field as OutputField evidently does — confirm.
    collection: list[ConditioningField] = InputField(
        default_factory=list, description="The collection of conditioning tensors", ui_type=UIType.ConditioningCollection
    )

    def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput:
        # Pass-through: forward the collection unchanged.
        return ConditioningCollectionOutput(collection=self.collection)
# endregion

View File

@ -1,59 +1,28 @@
from os.path import exists from os.path import exists
from typing import Literal, Optional from typing import Literal, Optional, Union
import numpy as np import numpy as np
from pydantic import Field, validator from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
from pydantic import validator
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext from invokeai.app.invocations.primitives import StringCollectionOutput
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, UIType, tags, title
class PromptOutput(BaseInvocationOutput):
"""Base class for invocations that output a prompt"""
# fmt: off
type: Literal["prompt"] = "prompt"
prompt: str = Field(default=None, description="The output prompt")
# fmt: on
class Config:
schema_extra = {
"required": [
"type",
"prompt",
]
}
class PromptCollectionOutput(BaseInvocationOutput):
"""Base class for invocations that output a collection of prompts"""
# fmt: off
type: Literal["prompt_collection_output"] = "prompt_collection_output"
prompt_collection: list[str] = Field(description="The output prompt collection")
count: int = Field(description="The size of the prompt collection")
# fmt: on
class Config:
schema_extra = {"required": ["type", "prompt_collection", "count"]}
@title("Dynamic Prompt")
@tags("prompt", "collection")
class DynamicPromptInvocation(BaseInvocation): class DynamicPromptInvocation(BaseInvocation):
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator""" """Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
type: Literal["dynamic_prompt"] = "dynamic_prompt" type: Literal["dynamic_prompt"] = "dynamic_prompt"
prompt: str = Field(description="The prompt to parse with dynamicprompts")
max_prompts: int = Field(default=1, description="The number of prompts to generate")
combinatorial: bool = Field(default=False, description="Whether to use the combinatorial generator")
class Config(InvocationConfig): # Inputs
schema_extra = { prompt: str = InputField(description="The prompt to parse with dynamicprompts", ui_component=UIComponent.Textarea)
"ui": {"title": "Dynamic Prompt", "tags": ["prompt", "dynamic"]}, max_prompts: int = InputField(default=1, description="The number of prompts to generate")
} combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator")
def invoke(self, context: InvocationContext) -> PromptCollectionOutput: def invoke(self, context: InvocationContext) -> StringCollectionOutput:
if self.combinatorial: if self.combinatorial:
generator = CombinatorialPromptGenerator() generator = CombinatorialPromptGenerator()
prompts = generator.generate(self.prompt, max_prompts=self.max_prompts) prompts = generator.generate(self.prompt, max_prompts=self.max_prompts)
@ -61,27 +30,26 @@ class DynamicPromptInvocation(BaseInvocation):
generator = RandomPromptGenerator() generator = RandomPromptGenerator()
prompts = generator.generate(self.prompt, num_images=self.max_prompts) prompts = generator.generate(self.prompt, num_images=self.max_prompts)
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts)) return StringCollectionOutput(collection=prompts)
@title("Prompts from File")
@tags("prompt", "file")
class PromptsFromFileInvocation(BaseInvocation): class PromptsFromFileInvocation(BaseInvocation):
"""Loads prompts from a text file""" """Loads prompts from a text file"""
# fmt: off type: Literal["prompt_from_file"] = "prompt_from_file"
type: Literal['prompt_from_file'] = 'prompt_from_file'
# Inputs # Inputs
file_path: str = Field(description="Path to prompt text file") file_path: str = InputField(description="Path to prompt text file", ui_type=UIType.FilePath)
pre_prompt: Optional[str] = Field(description="String to prepend to each prompt") pre_prompt: Optional[str] = InputField(
post_prompt: Optional[str] = Field(description="String to append to each prompt") default=None, description="String to prepend to each prompt", ui_component=UIComponent.Textarea
start_line: int = Field(default=1, ge=1, description="Line in the file to start start from") )
max_prompts: int = Field(default=1, ge=0, description="Max lines to read from file (0=all)") post_prompt: Optional[str] = InputField(
# fmt: on default=None, description="String to append to each prompt", ui_component=UIComponent.Textarea
)
class Config(InvocationConfig): start_line: int = InputField(default=1, ge=1, description="Line in the file to start start from")
schema_extra = { max_prompts: int = InputField(default=1, ge=0, description="Max lines to read from file (0=all)")
"ui": {"title": "Prompts From File", "tags": ["prompt", "file"]},
}
@validator("file_path") @validator("file_path")
def file_path_exists(cls, v): def file_path_exists(cls, v):
@ -89,7 +57,14 @@ class PromptsFromFileInvocation(BaseInvocation):
raise ValueError(FileNotFoundError) raise ValueError(FileNotFoundError)
return v return v
def promptsFromFile(self, file_path: str, pre_prompt: str, post_prompt: str, start_line: int, max_prompts: int): def promptsFromFile(
self,
file_path: str,
pre_prompt: Union[str, None],
post_prompt: Union[str, None],
start_line: int,
max_prompts: int,
):
prompts = [] prompts = []
start_line -= 1 start_line -= 1
end_line = start_line + max_prompts end_line = start_line + max_prompts
@ -103,8 +78,8 @@ class PromptsFromFileInvocation(BaseInvocation):
break break
return prompts return prompts
def invoke(self, context: InvocationContext) -> PromptCollectionOutput: def invoke(self, context: InvocationContext) -> StringCollectionOutput:
prompts = self.promptsFromFile( prompts = self.promptsFromFile(
self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts
) )
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts)) return StringCollectionOutput(collection=prompts)

View File

@ -1,55 +1,55 @@
import torch
from typing import Literal from typing import Literal
from pydantic import Field
from ...backend.model_management import ModelType, SubModelType from ...backend.model_management import ModelType, SubModelType
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext from .baseinvocation import (
from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
OutputField,
UIType,
tags,
title,
)
from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField
class SDXLModelLoaderOutput(BaseInvocationOutput): class SDXLModelLoaderOutput(BaseInvocationOutput):
"""SDXL base model loader output""" """SDXL base model loader output"""
# fmt: off
type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output" type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output"
unet: UNetField = Field(default=None, description="UNet submodel") unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
vae: VaeField = Field(default=None, description="Vae submodel") vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
# fmt: on
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput): class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
"""SDXL refiner model loader output""" """SDXL refiner model loader output"""
# fmt: off
type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output" type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
unet: UNetField = Field(default=None, description="UNet submodel")
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
vae: VaeField = Field(default=None, description="Vae submodel") clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
# fmt: on vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
# fmt: on
@title("SDXL Main Model Loader")
@tags("model", "sdxl")
class SDXLModelLoaderInvocation(BaseInvocation): class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels.""" """Loads an sdxl base model, outputting its submodels."""
type: Literal["sdxl_model_loader"] = "sdxl_model_loader" type: Literal["sdxl_model_loader"] = "sdxl_model_loader"
model: MainModelField = Field(description="The model to load") # Inputs
model: MainModelField = InputField(
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
)
# TODO: precision? # TODO: precision?
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Model Loader",
"tags": ["model", "loader", "sdxl"],
"type_hints": {"model": "model"},
},
}
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
base_model = self.model.base_model base_model = self.model.base_model
model_name = self.model.model_name model_name = self.model.model_name
@ -122,24 +122,21 @@ class SDXLModelLoaderInvocation(BaseInvocation):
) )
@title("SDXL Refiner Model Loader")
@tags("model", "sdxl", "refiner")
class SDXLRefinerModelLoaderInvocation(BaseInvocation): class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels.""" """Loads an sdxl refiner model, outputting its submodels."""
type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader" type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
model: MainModelField = Field(description="The model to load") # Inputs
model: MainModelField = InputField(
description=FieldDescriptions.sdxl_refiner_model,
input=Input.Direct,
ui_type=UIType.SDXLRefinerModel,
)
# TODO: precision? # TODO: precision?
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Refiner Model Loader",
"tags": ["model", "loader", "sdxl_refiner"],
"type_hints": {"model": "refiner_model"},
},
}
def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput: def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
base_model = self.model.base_model base_model = self.model.base_model
model_name = self.model.model_name model_name = self.model.model_name

View File

@ -6,13 +6,12 @@ import cv2 as cv
import numpy as np import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image from PIL import Image
from pydantic import Field
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin from invokeai.app.models.image import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext from .baseinvocation import BaseInvocation, InputField, InvocationContext, title, tags
from .image import ImageOutput
# TODO: Populate this from disk? # TODO: Populate this from disk?
# TODO: Use model manager to load? # TODO: Use model manager to load?
@ -24,17 +23,16 @@ ESRGAN_MODELS = Literal[
] ]
@title("Upscale (RealESRGAN)")
@tags("esrgan", "upscale")
class ESRGANInvocation(BaseInvocation): class ESRGANInvocation(BaseInvocation):
"""Upscales an image using RealESRGAN.""" """Upscales an image using RealESRGAN."""
type: Literal["esrgan"] = "esrgan" type: Literal["esrgan"] = "esrgan"
image: Union[ImageField, None] = Field(default=None, description="The input image")
model_name: ESRGAN_MODELS = Field(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
class Config(InvocationConfig): # Inputs
schema_extra = { image: ImageField = InputField(description="The input image")
"ui": {"title": "Upscale (RealESRGAN)", "tags": ["image", "upscale", "realesrgan"]}, model_name: ESRGAN_MODELS = InputField(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
}
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)

View File

@ -1,31 +1,8 @@
from enum import Enum from enum import Enum
from typing import Optional, Tuple, Literal
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.app.util.metaenum import MetaEnum from invokeai.app.util.metaenum import MetaEnum
from ..invocations.baseinvocation import (
BaseInvocationOutput,
InvocationConfig,
)
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
image_name: Optional[str] = Field(default=None, description="The name of the image")
class Config:
schema_extra = {"required": ["image_name"]}
class ColorField(BaseModel):
r: int = Field(ge=0, le=255, description="The red component")
g: int = Field(ge=0, le=255, description="The green component")
b: int = Field(ge=0, le=255, description="The blue component")
a: int = Field(ge=0, le=255, description="The alpha component")
def tuple(self) -> Tuple[int, int, int, int]:
return (self.r, self.g, self.b, self.a)
class ProgressImage(BaseModel): class ProgressImage(BaseModel):
@ -36,50 +13,6 @@ class ProgressImage(BaseModel):
dataURL: str = Field(description="The image data as a b64 data URL") dataURL: str = Field(description="The image data as a b64 data URL")
class PILInvocationConfig(BaseModel):
"""Helper class to provide all PIL invocations with additional config"""
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["PIL", "image"],
},
}
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
class Config:
schema_extra = {"required": ["type", "image", "width", "height"]}
class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask"""
# fmt: off
type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask")
width: int = Field(description="The width of the mask in pixels")
height: int = Field(description="The height of the mask in pixels")
# fmt: on
class Config:
schema_extra = {
"required": [
"type",
"mask",
]
}
class ResourceOrigin(str, Enum, metaclass=MetaEnum): class ResourceOrigin(str, Enum, metaclass=MetaEnum):
"""The origin of a resource (eg image). """The origin of a resource (eg image).

View File

@ -2,7 +2,7 @@ from ..invocations.latent import LatentsToImageInvocation, DenoiseLatentsInvocat
from ..invocations.image import ImageNSFWBlurInvocation from ..invocations.image import ImageNSFWBlurInvocation
from ..invocations.noise import NoiseInvocation from ..invocations.noise import NoiseInvocation
from ..invocations.compel import CompelInvocation from ..invocations.compel import CompelInvocation
from ..invocations.params import ParamIntInvocation from ..invocations.primitives import IntegerInvocation
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
from .item_storage import ItemStorageABC from .item_storage import ItemStorageABC
@ -17,9 +17,9 @@ def create_text_to_image() -> LibraryGraph:
description="Converts text to an image", description="Converts text to an image",
graph=Graph( graph=Graph(
nodes={ nodes={
"width": ParamIntInvocation(id="width", a=512), "width": IntegerInvocation(id="width", a=512),
"height": ParamIntInvocation(id="height", a=512), "height": IntegerInvocation(id="height", a=512),
"seed": ParamIntInvocation(id="seed", a=-1), "seed": IntegerInvocation(id="seed", a=-1),
"3": NoiseInvocation(id="3"), "3": NoiseInvocation(id="3"),
"4": CompelInvocation(id="4"), "4": CompelInvocation(id="4"),
"5": CompelInvocation(id="5"), "5": CompelInvocation(id="5"),

View File

@ -3,16 +3,7 @@
import copy import copy
import itertools import itertools
import uuid import uuid
from typing import ( from typing import Annotated, Any, Literal, Optional, Union, get_args, get_origin, get_type_hints
Annotated,
Any,
Literal,
Optional,
Union,
get_args,
get_origin,
get_type_hints,
)
import networkx as nx import networkx as nx
from pydantic import BaseModel, root_validator, validator from pydantic import BaseModel, root_validator, validator
@ -22,7 +13,11 @@ from ..invocations import *
from ..invocations.baseinvocation import ( from ..invocations.baseinvocation import (
BaseInvocation, BaseInvocation,
BaseInvocationOutput, BaseInvocationOutput,
Input,
InputField,
InvocationContext, InvocationContext,
OutputField,
UIType,
) )
# in 3.10 this would be "from types import NoneType" # in 3.10 this would be "from types import NoneType"
@ -183,15 +178,9 @@ class IterateInvocationOutput(BaseInvocationOutput):
type: Literal["iterate_output"] = "iterate_output" type: Literal["iterate_output"] = "iterate_output"
item: Any = Field(description="The item being iterated over") item: Any = OutputField(
description="The item being iterated over", title="Collection Item", ui_type=UIType.CollectionItem
class Config: )
schema_extra = {
"required": [
"type",
"item",
]
}
# TODO: Fill this out and move to invocations # TODO: Fill this out and move to invocations
@ -200,8 +189,10 @@ class IterateInvocation(BaseInvocation):
type: Literal["iterate"] = "iterate" type: Literal["iterate"] = "iterate"
collection: list[Any] = Field(description="The list of items to iterate over", default_factory=list) collection: list[Any] = InputField(
index: int = Field(description="The index, will be provided on executed iterators", default=0) description="The list of items to iterate over", default_factory=list, ui_type=UIType.Collection
)
index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True)
def invoke(self, context: InvocationContext) -> IterateInvocationOutput: def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
"""Produces the outputs as values""" """Produces the outputs as values"""
@ -211,15 +202,9 @@ class IterateInvocation(BaseInvocation):
class CollectInvocationOutput(BaseInvocationOutput): class CollectInvocationOutput(BaseInvocationOutput):
type: Literal["collect_output"] = "collect_output" type: Literal["collect_output"] = "collect_output"
collection: list[Any] = Field(description="The collection of input items") collection: list[Any] = OutputField(
description="The collection of input items", title="Collection", ui_type=UIType.Collection
class Config: )
schema_extra = {
"required": [
"type",
"collection",
]
}
class CollectInvocation(BaseInvocation): class CollectInvocation(BaseInvocation):
@ -227,13 +212,14 @@ class CollectInvocation(BaseInvocation):
type: Literal["collect"] = "collect" type: Literal["collect"] = "collect"
item: Any = Field( item: Any = InputField(
description="The item to collect (all inputs must be of the same type)", description="The item to collect (all inputs must be of the same type)",
default=None, ui_type=UIType.CollectionItem,
title="Collection Item",
input=Input.Connection,
) )
collection: list[Any] = Field( collection: list[Any] = InputField(
description="The collection, will be provided on execution", description="The collection, will be provided on execution", default_factory=list, ui_hidden=True
default_factory=list,
) )
def invoke(self, context: InvocationContext) -> CollectInvocationOutput: def invoke(self, context: InvocationContext) -> CollectInvocationOutput:

View File

@ -67,6 +67,7 @@ IMAGE_DTO_COLS = ", ".join(
"created_at", "created_at",
"updated_at", "updated_at",
"deleted_at", "deleted_at",
"starred",
], ],
) )
) )
@ -139,6 +140,7 @@ class ImageRecordStorageBase(ABC):
node_id: Optional[str], node_id: Optional[str],
metadata: Optional[dict], metadata: Optional[dict],
is_intermediate: bool = False, is_intermediate: bool = False,
starred: bool = False,
) -> datetime: ) -> datetime:
"""Saves an image record.""" """Saves an image record."""
pass pass
@ -200,6 +202,16 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
""" """
) )
self._cursor.execute("PRAGMA table_info(images)")
columns = [column[1] for column in self._cursor.fetchall()]
if "starred" not in columns:
self._cursor.execute(
"""--sql
ALTER TABLE images ADD COLUMN starred BOOLEAN DEFAULT FALSE;
"""
)
# Create the `images` table indices. # Create the `images` table indices.
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
@ -222,6 +234,12 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
""" """
) )
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_starred ON images(starred);
"""
)
# Add trigger for `updated_at`. # Add trigger for `updated_at`.
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
@ -321,6 +339,17 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
(changes.is_intermediate, image_name), (changes.is_intermediate, image_name),
) )
# Change the image's `starred`` state
if changes.starred is not None:
self._cursor.execute(
f"""--sql
UPDATE images
SET starred = ?
WHERE image_name = ?;
""",
(changes.starred, image_name),
)
self._conn.commit() self._conn.commit()
except sqlite3.Error as e: except sqlite3.Error as e:
self._conn.rollback() self._conn.rollback()
@ -397,7 +426,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
query_params.append(board_id) query_params.append(board_id)
query_pagination = """--sql query_pagination = """--sql
ORDER BY images.created_at DESC LIMIT ? OFFSET ? ORDER BY images.starred DESC, images.created_at DESC LIMIT ? OFFSET ?
""" """
# Final images query with pagination # Final images query with pagination
@ -500,6 +529,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id: Optional[str], node_id: Optional[str],
metadata: Optional[dict], metadata: Optional[dict],
is_intermediate: bool = False, is_intermediate: bool = False,
starred: bool = False,
) -> datetime: ) -> datetime:
try: try:
metadata_json = None if metadata is None else json.dumps(metadata) metadata_json = None if metadata is None else json.dumps(metadata)
@ -515,9 +545,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id, node_id,
session_id, session_id,
metadata, metadata,
is_intermediate is_intermediate,
starred
) )
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?); VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
""", """,
( (
image_name, image_name,
@ -529,6 +560,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
session_id, session_id,
metadata_json, metadata_json,
is_intermediate, is_intermediate,
starred,
), ),
) )
self._conn.commit() self._conn.commit()

View File

@ -39,6 +39,8 @@ class ImageRecord(BaseModelExcludeNull):
description="The node ID that generated this image, if it is a generated image.", description="The node ID that generated this image, if it is a generated image.",
) )
"""The node ID that generated this image, if it is a generated image.""" """The node ID that generated this image, if it is a generated image."""
starred: bool = Field(description="Whether this image is starred.")
"""Whether this image is starred."""
class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid): class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
@ -48,6 +50,7 @@ class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
- `image_category`: change the category of an image - `image_category`: change the category of an image
- `session_id`: change the session associated with an image - `session_id`: change the session associated with an image
- `is_intermediate`: change the image's `is_intermediate` flag - `is_intermediate`: change the image's `is_intermediate` flag
- `starred`: change whether the image is starred
""" """
image_category: Optional[ImageCategory] = Field(description="The image's new category.") image_category: Optional[ImageCategory] = Field(description="The image's new category.")
@ -59,6 +62,8 @@ class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
"""The image's new session ID.""" """The image's new session ID."""
is_intermediate: Optional[StrictBool] = Field(default=None, description="The image's new `is_intermediate` flag.") is_intermediate: Optional[StrictBool] = Field(default=None, description="The image's new `is_intermediate` flag.")
"""The image's new `is_intermediate` flag.""" """The image's new `is_intermediate` flag."""
starred: Optional[StrictBool] = Field(default=None, description="The image's new `starred` state")
"""The image's new `starred` state."""
class ImageUrlsDTO(BaseModelExcludeNull): class ImageUrlsDTO(BaseModelExcludeNull):
@ -113,6 +118,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
updated_at = image_dict.get("updated_at", get_iso_timestamp()) updated_at = image_dict.get("updated_at", get_iso_timestamp())
deleted_at = image_dict.get("deleted_at", get_iso_timestamp()) deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
is_intermediate = image_dict.get("is_intermediate", False) is_intermediate = image_dict.get("is_intermediate", False)
starred = image_dict.get("starred", False)
return ImageRecord( return ImageRecord(
image_name=image_name, image_name=image_name,
@ -126,4 +132,5 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
updated_at=updated_at, updated_at=updated_at,
deleted_at=deleted_at, deleted_at=deleted_at,
is_intermediate=is_intermediate, is_intermediate=is_intermediate,
starred=starred,
) )

View File

@ -87,7 +87,10 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Invoke # Invoke
try: try:
with statistics.collect_stats(invocation, graph_execution_state.id): with statistics.collect_stats(invocation, graph_execution_state.id):
outputs = invocation.invoke( # use the internal invoke_internal(), which wraps the node's invoke() method in
# this accomodates nodes which require a value, but get it only from a
# connection
outputs = invocation.invoke_internal(
InvocationContext( InvocationContext(
services=self.__invoker.services, services=self.__invoker.services,
graph_execution_state_id=graph_execution_state.id, graph_execution_state_id=graph_execution_state.id,

View File

@ -49,7 +49,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
def _parse_item(self, item: str) -> T: def _parse_item(self, item: str) -> T:
item_type = get_args(self.__orig_class__)[0] item_type = get_args(self.__orig_class__)[0]
return parse_raw_as(item_type, item) parsed = parse_raw_as(item_type, item)
return parsed
def set(self, item: T): def set(self, item: T):
try: try:

View File

@ -109,7 +109,7 @@ class ModelMerger(object):
# pick up the first model's vae # pick up the first model's vae
if mod == model_names[0]: if mod == model_names[0]:
vae = info.get("vae") vae = info.get("vae")
model_paths.extend([config.root_path / info["path"]]) model_paths.extend([(config.root_path / info["path"]).as_posix()])
merge_method = None if interp == "weighted_sum" else MergeInterpolationMethod(interp) merge_method = None if interp == "weighted_sum" else MergeInterpolationMethod(interp)
logger.debug(f"interp = {interp}, merge_method={merge_method}") logger.debug(f"interp = {interp}, merge_method={merge_method}")
@ -120,11 +120,11 @@ class ModelMerger(object):
else config.models_path / base_model.value / ModelType.Main.value else config.models_path / base_model.value / ModelType.Main.value
) )
dump_path.mkdir(parents=True, exist_ok=True) dump_path.mkdir(parents=True, exist_ok=True)
dump_path = dump_path / merged_model_name dump_path = (dump_path / merged_model_name).as_posix()
merged_pipe.save_pretrained(dump_path, safe_serialization=True) merged_pipe.save_pretrained(dump_path, safe_serialization=True)
attributes = dict( attributes = dict(
path=str(dump_path), path=dump_path,
description=f"Merge of models {', '.join(model_names)}", description=f"Merge of models {', '.join(model_names)}",
model_format="diffusers", model_format="diffusers",
variant=ModelVariantType.Normal.value, variant=ModelVariantType.Normal.value,

View File

@ -481,9 +481,19 @@ class ControlNetFolderProbe(FolderProbeBase):
with open(config_file, "r") as file: with open(config_file, "r") as file:
config = json.load(file) config = json.load(file)
# no obvious way to distinguish between sd2-base and sd2-768 # no obvious way to distinguish between sd2-base and sd2-768
return ( dimension = config["cross_attention_dim"]
BaseModelType.StableDiffusion1 if config["cross_attention_dim"] == 768 else BaseModelType.StableDiffusion2 base_model = (
BaseModelType.StableDiffusion1
if dimension == 768
else BaseModelType.StableDiffusion2
if dimension == 1024
else BaseModelType.StableDiffusionXL
if dimension == 2048
else None
) )
if not base_model:
raise InvalidModelException(f"Unable to determine model base for {self.folder_path}")
return base_model
class LoRAFolderProbe(FolderProbeBase): class LoRAFolderProbe(FolderProbeBase):

View File

@ -1,4 +0,0 @@
"""
Initialization file for the web backend.
"""
from .invoke_ai_web_server import InvokeAIWebServer

File diff suppressed because it is too large Load Diff

View File

@ -1,56 +0,0 @@
import argparse
import os
from ...args import PRECISION_CHOICES
def create_cmd_parser():
parser = argparse.ArgumentParser(description="InvokeAI web UI")
parser.add_argument(
"--host",
type=str,
help="The host to serve on",
default="localhost",
)
parser.add_argument("--port", type=int, help="The port to serve on", default=9090)
parser.add_argument(
"--cors",
nargs="*",
type=str,
help="Additional allowed origins, comma-separated",
)
parser.add_argument(
"--embedding_path",
type=str,
help="Path to a pre-trained embedding manager checkpoint - can only be set on command line",
)
# TODO: Can't get flask to serve images from any dir (saving to the dir does work when specified)
# parser.add_argument(
# "--output_dir",
# default="outputs/",
# type=str,
# help="Directory for output images",
# )
parser.add_argument(
"-v",
"--verbose",
action="store_true",
help="Enables verbose logging",
)
parser.add_argument(
"--precision",
dest="precision",
type=str,
choices=PRECISION_CHOICES,
metavar="PRECISION",
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
default="auto",
)
parser.add_argument(
"--free_gpu_mem",
dest="free_gpu_mem",
action="store_true",
help="Force free gpu memory before final decoding",
)
return parser

View File

@ -1,113 +0,0 @@
from typing import Literal, Union
from PIL import Image, ImageChops
from PIL.Image import Image as ImageType
# https://stackoverflow.com/questions/43864101/python-pil-check-if-image-is-transparent
def check_for_any_transparency(img: Union[ImageType, str]) -> bool:
if type(img) is str:
img = Image.open(str)
if img.info.get("transparency", None) is not None:
return True
if img.mode == "P":
transparent = img.info.get("transparency", -1)
for _, index in img.getcolors():
if index == transparent:
return True
elif img.mode == "RGBA":
extrema = img.getextrema()
if extrema[3][0] < 255:
return True
return False
def get_canvas_generation_mode(
init_img: Union[ImageType, str], init_mask: Union[ImageType, str]
) -> Literal["txt2img", "outpainting", "inpainting", "img2img",]:
if type(init_img) is str:
init_img = Image.open(init_img)
if type(init_mask) is str:
init_mask = Image.open(init_mask)
init_img = init_img.convert("RGBA")
# Get alpha from init_img
init_img_alpha = init_img.split()[-1]
init_img_alpha_mask = init_img_alpha.convert("L")
init_img_has_transparency = check_for_any_transparency(init_img)
if init_img_has_transparency:
init_img_is_fully_transparent = True if init_img_alpha_mask.getbbox() is None else False
"""
Mask images are white in areas where no change should be made, black where changes
should be made.
"""
# Fit the mask to init_img's size and convert it to greyscale
init_mask = init_mask.resize(init_img.size).convert("L")
"""
PIL.Image.getbbox() returns the bounding box of non-zero areas of the image, so we first
invert the mask image so that masked areas are white and other areas black == zero.
getbbox() now tells us if the are any masked areas.
"""
init_mask_bbox = ImageChops.invert(init_mask).getbbox()
init_mask_exists = False if init_mask_bbox is None else True
if init_img_has_transparency:
if init_img_is_fully_transparent:
return "txt2img"
else:
return "outpainting"
else:
if init_mask_exists:
return "inpainting"
else:
return "img2img"
def main():
# Testing
init_img_opaque = "test_images/init-img_opaque.png"
init_img_partial_transparency = "test_images/init-img_partial_transparency.png"
init_img_full_transparency = "test_images/init-img_full_transparency.png"
init_mask_no_mask = "test_images/init-mask_no_mask.png"
init_mask_has_mask = "test_images/init-mask_has_mask.png"
print(
"OPAQUE IMAGE, NO MASK, expect img2img, got ",
get_canvas_generation_mode(init_img_opaque, init_mask_no_mask),
)
print(
"IMAGE WITH TRANSPARENCY, NO MASK, expect outpainting, got ",
get_canvas_generation_mode(init_img_partial_transparency, init_mask_no_mask),
)
print(
"FULLY TRANSPARENT IMAGE NO MASK, expect txt2img, got ",
get_canvas_generation_mode(init_img_full_transparency, init_mask_no_mask),
)
print(
"OPAQUE IMAGE, WITH MASK, expect inpainting, got ",
get_canvas_generation_mode(init_img_opaque, init_mask_has_mask),
)
print(
"IMAGE WITH TRANSPARENCY, WITH MASK, expect outpainting, got ",
get_canvas_generation_mode(init_img_partial_transparency, init_mask_has_mask),
)
print(
"FULLY TRANSPARENT IMAGE WITH MASK, expect txt2img, got ",
get_canvas_generation_mode(init_img_full_transparency, init_mask_has_mask),
)
if __name__ == "__main__":
main()

View File

@ -1,82 +0,0 @@
import argparse
from .parse_seed_weights import parse_seed_weights
SAMPLER_CHOICES = [
"ddim",
"ddpm",
"deis",
"lms",
"lms_k",
"pndm",
"heun",
"heun_k",
"euler",
"euler_k",
"euler_a",
"kdpm_2",
"kdpm_2_a",
"dpmpp_2s",
"dpmpp_2s_k",
"dpmpp_2m",
"dpmpp_2m_k",
"dpmpp_2m_sde",
"dpmpp_2m_sde_k",
"dpmpp_sde",
"dpmpp_sde_k",
"unipc",
]
def parameters_to_command(params):
    """Convert a dict of generation parameters into an `invoke.py` REPL command.

    Only keys present in ``params`` contribute a switch; unknown keys are
    ignored.  Returns the switches joined by single spaces (an empty string
    when nothing applies).

    Fixes over the previous version: constant switches no longer use
    f-strings without placeholders (lint F541), boolean flags use truthiness
    instead of ``== True`` (lint E712), and ``switches`` is a list literal.
    """
    switches = []
    if "prompt" in params:
        switches.append(f'"{params["prompt"]}"')
    if "steps" in params:
        switches.append(f'-s {params["steps"]}')
    if "seed" in params:
        switches.append(f'-S {params["seed"]}')
    if "width" in params:
        switches.append(f'-W {params["width"]}')
    if "height" in params:
        switches.append(f'-H {params["height"]}')
    if "cfg_scale" in params:
        switches.append(f'-C {params["cfg_scale"]}')
    if "sampler_name" in params:
        switches.append(f'-A {params["sampler_name"]}')
    if params.get("seamless"):
        switches.append("--seamless")
    if params.get("hires_fix"):
        switches.append("--hires")
    if "init_img" in params and len(params["init_img"]) > 0:
        switches.append(f'-I {params["init_img"]}')
    if "init_mask" in params and len(params["init_mask"]) > 0:
        switches.append(f'-M {params["init_mask"]}')
    if "init_color" in params and len(params["init_color"]) > 0:
        switches.append(f'--init_color {params["init_color"]}')
    # Strength only makes sense together with an init image.
    if "strength" in params and "init_img" in params:
        switches.append(f'-f {params["strength"]}')
    if params.get("fit"):
        switches.append("--fit")
    if "facetool" in params:
        switches.append(f'-ft {params["facetool"]}')
    # Prefer the generic facetool strength; fall back to the legacy
    # gfpgan-specific key for older payloads.
    if "facetool_strength" in params and params["facetool_strength"]:
        switches.append(f'-G {params["facetool_strength"]}')
    elif "gfpgan_strength" in params and params["gfpgan_strength"]:
        switches.append(f'-G {params["gfpgan_strength"]}')
    if "codeformer_fidelity" in params:
        switches.append(f'-cf {params["codeformer_fidelity"]}')
    if "upscale" in params and params["upscale"]:
        switches.append(f'-U {params["upscale"][0]} {params["upscale"][1]}')
    if "variation_amount" in params and params["variation_amount"] > 0:
        switches.append(f'-v {params["variation_amount"]}')
    if "with_variations" in params:
        seed_weight_pairs = ",".join(f"{seed}:{weight}" for seed, weight in params["with_variations"])
        switches.append(f"-V {seed_weight_pairs}")
    return " ".join(switches)

View File

@ -1,47 +0,0 @@
def parse_seed_weights(seed_weights):
    """
    Accepts seed weights as string in "12345:0.1,23456:0.2,3456:0.3" format
    Validates them
    If valid: returns as [[12345, 0.1], [23456, 0.2], [3456, 0.3]]
    If invalid: returns False
    """
    # Must be a non-empty string
    if not isinstance(seed_weights, str) or len(seed_weights) == 0:
        return False

    pairs = []
    for pair in seed_weights.split(","):
        split_values = pair.split(":")

        # Exactly one seed and one weight are required per pair
        if len(split_values) != 2:
            return False

        # Neither side may be empty.
        # Bug fix: the weight check previously tested `len(...) == 1`,
        # wrongly rejecting single-character weights such as "1".
        if len(split_values[0]) == 0 or len(split_values[1]) == 0:
            return False

        # Try casting the seed to int and weight to float
        try:
            seed = int(split_values[0])
            weight = float(split_values[1])
        except ValueError:
            return False

        # Seed must be 0 or above
        if seed < 0:
            return False

        # Weight must be between 0 and 1 inclusive
        if not (0 <= weight <= 1):
            return False

        # This pair is valid
        pairs.append([seed, weight])

    # All pairs are valid
    return pairs

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.7 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 292 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 164 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 9.5 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.4 KiB

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -1,4 +1,4 @@
import{B as m,g7 as Je,A as y,a5 as Ka,g8 as Xa,af as va,aj as d,g9 as b,ga as t,gb as Ya,gc as h,gd as ua,ge as Ja,gf as Qa,aL as Za,gg as et,ad as rt,gh as at}from"./index-deaa1f26.js";import{s as fa,n as o,t as tt,o as ha,p as ot,q as ma,v as ga,w as ya,x as it,y as Sa,z as pa,A as xr,B as nt,D as lt,E as st,F as xa,G as $a,H as ka,J as dt,K as _a,L as ct,M as bt,N as vt,O as ut,Q as wa,R as ft,S as ht,T as mt,U as gt,V as yt,W as St,e as pt,X as xt}from"./menu-b4489359.js";var za=String.raw,Ca=za` import{B as m,g7 as Je,A as y,a5 as Ka,g8 as Xa,af as va,aj as d,g9 as b,ga as t,gb as Ya,gc as h,gd as ua,ge as Ja,gf as Qa,aL as Za,gg as et,ad as rt,gh as at}from"./index-2c171c8f.js";import{s as fa,n as o,t as tt,o as ha,p as ot,q as ma,v as ga,w as ya,x as it,y as Sa,z as pa,A as xr,B as nt,D as lt,E as st,F as xa,G as $a,H as ka,J as dt,K as _a,L as ct,M as bt,N as vt,O as ut,Q as wa,R as ft,S as ht,T as mt,U as gt,V as yt,W as St,e as pt,X as xt}from"./menu-971c0572.js";var za=String.raw,Ca=za`
:root, :root,
:host { :host {
--chakra-vh: 100vh; --chakra-vh: 100vh;

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -12,7 +12,7 @@
margin: 0; margin: 0;
} }
</style> </style>
<script type="module" crossorigin src="./assets/index-deaa1f26.js"></script> <script type="module" crossorigin src="./assets/index-2c171c8f.js"></script>
</head> </head>
<body dir="ltr"> <body dir="ltr">

View File

@ -503,6 +503,9 @@
"hiresStrength": "High Res Strength", "hiresStrength": "High Res Strength",
"imageFit": "Fit Initial Image To Output Size", "imageFit": "Fit Initial Image To Output Size",
"codeformerFidelity": "Fidelity", "codeformerFidelity": "Fidelity",
"maskAdjustmentsHeader": "Mask Adjustments",
"maskBlur": "Mask Blur",
"maskBlurMethod": "Mask Blur Method",
"seamSize": "Seam Size", "seamSize": "Seam Size",
"seamBlur": "Seam Blur", "seamBlur": "Seam Blur",
"seamStrength": "Seam Strength", "seamStrength": "Seam Strength",

View File

@ -61,6 +61,7 @@
"@dagrejs/graphlib": "^2.1.13", "@dagrejs/graphlib": "^2.1.13",
"@dnd-kit/core": "^6.0.8", "@dnd-kit/core": "^6.0.8",
"@dnd-kit/modifiers": "^6.0.1", "@dnd-kit/modifiers": "^6.0.1",
"@dnd-kit/utilities": "^3.2.1",
"@emotion/react": "^11.11.1", "@emotion/react": "^11.11.1",
"@emotion/styled": "^11.11.0", "@emotion/styled": "^11.11.0",
"@floating-ui/react-dom": "^2.0.1", "@floating-ui/react-dom": "^2.0.1",

View File

@ -0,0 +1,34 @@
// ANSI terminal escape codes used to colorize CLI output.
// Top-level keys are text attributes; `fg` and `bg` map color names to
// foreground/background SGR codes. Write a code, then `reset` to restore
// default styling.
// NOTE(review): `crimson` uses the extended-color introducers (38/48)
// without the trailing `;5;n` / `;2;r;g;b` parameters — confirm this
// renders as intended on target terminals.
export const COLORS = {
  reset: '\x1b[0m',
  bright: '\x1b[1m',
  dim: '\x1b[2m',
  underscore: '\x1b[4m',
  blink: '\x1b[5m',
  reverse: '\x1b[7m',
  hidden: '\x1b[8m',
  fg: {
    black: '\x1b[30m',
    red: '\x1b[31m',
    green: '\x1b[32m',
    yellow: '\x1b[33m',
    blue: '\x1b[34m',
    magenta: '\x1b[35m',
    cyan: '\x1b[36m',
    white: '\x1b[37m',
    gray: '\x1b[90m',
    crimson: '\x1b[38m',
  },
  bg: {
    black: '\x1b[40m',
    red: '\x1b[41m',
    green: '\x1b[42m',
    yellow: '\x1b[43m',
    blue: '\x1b[44m',
    magenta: '\x1b[45m',
    cyan: '\x1b[46m',
    white: '\x1b[47m',
    gray: '\x1b[100m',
    crimson: '\x1b[48m',
  },
};

View File

@ -1,23 +1,83 @@
import fs from 'node:fs'; import fs from 'node:fs';
import openapiTS from 'openapi-typescript'; import openapiTS from 'openapi-typescript';
import { COLORS } from './colors.js';
const OPENAPI_URL = 'http://127.0.0.1:9090/openapi.json'; const OPENAPI_URL = 'http://127.0.0.1:9090/openapi.json';
const OUTPUT_FILE = 'src/services/api/schema.d.ts'; const OUTPUT_FILE = 'src/services/api/schema.d.ts';
async function main() { async function main() {
process.stdout.write( process.stdout.write(
`Generating types "${OPENAPI_URL}" --> "${OUTPUT_FILE}"...` `Generating types "${OPENAPI_URL}" --> "${OUTPUT_FILE}"...\n\n`
); );
const types = await openapiTS(OPENAPI_URL, { const types = await openapiTS(OPENAPI_URL, {
exportType: true, exportType: true,
transform: (schemaObject) => { transform: (schemaObject, metadata) => {
if ('format' in schemaObject && schemaObject.format === 'binary') { if ('format' in schemaObject && schemaObject.format === 'binary') {
return schemaObject.nullable ? 'Blob | null' : 'Blob'; return schemaObject.nullable ? 'Blob | null' : 'Blob';
} }
/**
* Because invocations may have required fields that accept connection input, the generated
* types may be incorrect.
*
* For example, the ImageResizeInvocation has a required `image` field, but because it accepts
* connection input, it should be optional on instantiation of the field.
*
* To handle this, the schema exposes an `input` property that can be used to determine if the
* field accepts connection input. If it does, we can make the field optional.
*/
// Check if we are generating types for an invocation
const isInvocationPath = metadata.path.match(
/^#\/components\/schemas\/\w*Invocation$/
);
const hasInvocationProperties =
schemaObject.properties &&
['id', 'is_intermediate', 'type'].every(
(prop) => prop in schemaObject.properties
);
if (isInvocationPath && hasInvocationProperties) {
// We only want to make fields optional if they are required
if (!Array.isArray(schemaObject?.required)) {
schemaObject.required = ['id', 'type'];
return;
}
schemaObject.required.forEach((prop) => {
const acceptsConnection = ['any', 'connection'].includes(
schemaObject.properties?.[prop]?.['input']
);
if (acceptsConnection) {
// remove this prop from the required array
const invocationName = metadata.path.split('/').pop();
console.log(
`Making connectable field optional: ${COLORS.fg.green}${invocationName}.${COLORS.fg.cyan}${prop}${COLORS.reset}`
);
schemaObject.required = schemaObject.required.filter(
(r) => r !== prop
);
}
});
schemaObject.required = [
...new Set(schemaObject.required.concat(['id', 'type'])),
];
return;
}
// if (
// 'input' in schemaObject &&
// (schemaObject.input === 'any' || schemaObject.input === 'connection')
// ) {
// schemaObject.required = false;
// }
}, },
}); });
fs.writeFileSync(OUTPUT_FILE, types); fs.writeFileSync(OUTPUT_FILE, types);
process.stdout.write(` OK!\r\n`); process.stdout.write(`\nOK!\r\n`);
} }
main(); main();

View File

@ -1,8 +1,12 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice'; import {
ctrlKeyPressed,
metaKeyPressed,
shiftKeyPressed,
} from 'features/ui/store/hotkeysSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { import {
setActiveTab, setActiveTab,
@ -16,11 +20,11 @@ import React, { memo } from 'react';
import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook'; import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
const globalHotkeysSelector = createSelector( const globalHotkeysSelector = createSelector(
[(state: RootState) => state.hotkeys, (state: RootState) => state.ui], [stateSelector],
(hotkeys, ui) => { ({ hotkeys, ui }) => {
const { shift } = hotkeys; const { shift, ctrl, meta } = hotkeys;
const { shouldPinParametersPanel, shouldPinGallery } = ui; const { shouldPinParametersPanel, shouldPinGallery } = ui;
return { shift, shouldPinGallery, shouldPinParametersPanel }; return { shift, ctrl, meta, shouldPinGallery, shouldPinParametersPanel };
}, },
{ {
memoizeOptions: { memoizeOptions: {
@ -37,9 +41,8 @@ const globalHotkeysSelector = createSelector(
*/ */
const GlobalHotkeys: React.FC = () => { const GlobalHotkeys: React.FC = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { shift, shouldPinParametersPanel, shouldPinGallery } = useAppSelector( const { shift, ctrl, meta, shouldPinParametersPanel, shouldPinGallery } =
globalHotkeysSelector useAppSelector(globalHotkeysSelector);
);
const activeTabName = useAppSelector(activeTabNameSelector); const activeTabName = useAppSelector(activeTabNameSelector);
useHotkeys( useHotkeys(
@ -50,9 +53,19 @@ const GlobalHotkeys: React.FC = () => {
} else { } else {
shift && dispatch(shiftKeyPressed(false)); shift && dispatch(shiftKeyPressed(false));
} }
if (isHotkeyPressed('ctrl')) {
!ctrl && dispatch(ctrlKeyPressed(true));
} else {
ctrl && dispatch(ctrlKeyPressed(false));
}
if (isHotkeyPressed('meta')) {
!meta && dispatch(metaKeyPressed(true));
} else {
meta && dispatch(metaKeyPressed(false));
}
}, },
{ keyup: true, keydown: true }, { keyup: true, keydown: true },
[shift] [shift, ctrl, meta]
); );
useHotkeys('o', () => { useHotkeys('o', () => {

View File

@ -14,7 +14,7 @@ import { $authToken, $baseUrl, $projectId } from 'services/api/client';
import { socketMiddleware } from 'services/events/middleware'; import { socketMiddleware } from 'services/events/middleware';
import Loading from '../../common/components/Loading/Loading'; import Loading from '../../common/components/Loading/Loading';
import '../../i18n'; import '../../i18n';
import ImageDndContext from './ImageDnd/ImageDndContext'; import AppDndContext from '../../features/dnd/components/AppDndContext';
const App = lazy(() => import('./App')); const App = lazy(() => import('./App'));
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider')); const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
@ -80,9 +80,9 @@ const InvokeAIUI = ({
<Provider store={store}> <Provider store={store}>
<React.Suspense fallback={<Loading />}> <React.Suspense fallback={<Loading />}>
<ThemeLocaleProvider> <ThemeLocaleProvider>
<ImageDndContext> <AppDndContext>
<App config={config} headerComponent={headerComponent} /> <App config={config} headerComponent={headerComponent} />
</ImageDndContext> </AppDndContext>
</ThemeLocaleProvider> </ThemeLocaleProvider>
</React.Suspense> </React.Suspense>
</Provider> </Provider>

View File

@ -19,7 +19,8 @@ type LoggerNamespace =
| 'nodes' | 'nodes'
| 'system' | 'system'
| 'socketio' | 'socketio'
| 'session'; | 'session'
| 'dnd';
export const logger = (namespace: LoggerNamespace) => export const logger = (namespace: LoggerNamespace) =>
$logger.get().child({ namespace }); $logger.get().child({ namespace });

View File

@ -15,7 +15,7 @@ export const actionsDenylist = [
'socket/socketGeneratorProgress', 'socket/socketGeneratorProgress',
'socket/appSocketGeneratorProgress', 'socket/appSocketGeneratorProgress',
// every time user presses shift // every time user presses shift
'hotkeys/shiftKeyPressed', // 'hotkeys/shiftKeyPressed',
// this happens after every state change // this happens after every state change
'@@REMEMBER_PERSISTED', '@@REMEMBER_PERSISTED',
]; ];

View File

@ -15,6 +15,7 @@ import { addDeleteBoardAndImagesFulfilledListener } from './listeners/boardAndIm
import { addBoardIdSelectedListener } from './listeners/boardIdSelected'; import { addBoardIdSelectedListener } from './listeners/boardIdSelected';
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard'; import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage'; import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
import { addCanvasMaskSavedToGalleryListener } from './listeners/canvasMaskSavedToGallery';
import { addCanvasMergedListener } from './listeners/canvasMerged'; import { addCanvasMergedListener } from './listeners/canvasMerged';
import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery'; import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery';
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess'; import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
@ -27,8 +28,8 @@ import {
addImageDeletedFulfilledListener, addImageDeletedFulfilledListener,
addImageDeletedPendingListener, addImageDeletedPendingListener,
addImageDeletedRejectedListener, addImageDeletedRejectedListener,
addRequestedSingleImageDeletionListener,
addRequestedMultipleImageDeletionListener, addRequestedMultipleImageDeletionListener,
addRequestedSingleImageDeletionListener,
} from './listeners/imageDeleted'; } from './listeners/imageDeleted';
import { addImageDroppedListener } from './listeners/imageDropped'; import { addImageDroppedListener } from './listeners/imageDropped';
import { import {
@ -79,6 +80,8 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage'; import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes'; import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage'; import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
import { addImagesStarredListener } from './listeners/imagesStarred';
import { addImagesUnstarredListener } from './listeners/imagesUnstarred';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
@ -120,6 +123,10 @@ addImageDeletedRejectedListener();
addDeleteBoardAndImagesFulfilledListener(); addDeleteBoardAndImagesFulfilledListener();
addImageToDeleteSelectedListener(); addImageToDeleteSelectedListener();
// Image starred
addImagesStarredListener();
addImagesUnstarredListener();
// User Invoked // User Invoked
addUserInvokedCanvasListener(); addUserInvokedCanvasListener();
addUserInvokedNodesListener(); addUserInvokedNodesListener();
@ -129,6 +136,7 @@ addSessionReadyToInvokeListener();
// Canvas actions // Canvas actions
addCanvasSavedToGalleryListener(); addCanvasSavedToGalleryListener();
addCanvasMaskSavedToGalleryListener();
addCanvasDownloadedAsImageListener(); addCanvasDownloadedAsImageListener();
addCanvasCopiedToClipboardListener(); addCanvasCopiedToClipboardListener();
addCanvasMergedListener(); addCanvasMergedListener();

View File

@ -0,0 +1,60 @@
import { logger } from 'app/logging/logger';
import { canvasMaskSavedToGallery } from 'features/canvas/store/actions';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { addToast } from 'features/system/store/systemSlice';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..';
/**
 * Listens for the `canvasMaskSavedToGallery` action: exports the canvas mask
 * layer as a PNG blob and uploads it as a non-intermediate 'mask' asset,
 * surfacing an error toast when the mask blob cannot be produced.
 */
export const addCanvasMaskSavedToGalleryListener = () => {
  startAppListening({
    actionCreator: canvasMaskSavedToGallery,
    effect: async (action, { dispatch, getState }) => {
      const log = logger('canvas');
      const state = getState();

      // Render blobs/image data for the current bounding box and layer state.
      const canvasBlobsAndImageData = await getCanvasData(
        state.canvas.layerState,
        state.canvas.boundingBoxCoordinates,
        state.canvas.boundingBoxDimensions,
        state.canvas.isMaskEnabled,
        state.canvas.shouldPreserveMaskedArea
      );

      // getCanvasData returning nothing means there is nothing to export;
      // bail silently (no toast) in that case.
      if (!canvasBlobsAndImageData) {
        return;
      }

      const { maskBlob } = canvasBlobsAndImageData;

      // A missing mask blob is an error worth telling the user about.
      if (!maskBlob) {
        log.error('Problem getting mask layer blob');
        dispatch(
          addToast({
            title: 'Problem Saving Mask',
            description: 'Unable to export mask',
            status: 'error',
          })
        );
        return;
      }

      const { autoAddBoardId } = state.gallery;

      // Upload the mask as a gallery asset; 'none' means no auto-add board.
      dispatch(
        imagesApi.endpoints.uploadImage.initiate({
          file: new File([maskBlob], 'canvasMaskImage.png', {
            type: 'image/png',
          }),
          image_category: 'mask',
          is_intermediate: false,
          board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
          crop_visible: true,
          postUploadAction: {
            type: 'TOAST',
            toastOptions: { title: 'Mask Saved to Assets' },
          },
        })
      );
    },
  });
};

View File

@ -1,16 +1,20 @@
import { createAction } from '@reduxjs/toolkit'; import { createAction } from '@reduxjs/toolkit';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd';
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice'; import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'features/dnd/types';
import { imageSelected } from 'features/gallery/store/gallerySlice'; import { imageSelected } from 'features/gallery/store/gallerySlice';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import {
fieldImageValueChanged,
workflowExposedFieldAdded,
} from 'features/nodes/store/nodesSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '../'; import { startAppListening } from '../';
import { parseify } from 'common/util/serialize';
export const dndDropped = createAction<{ export const dndDropped = createAction<{
overData: TypesafeDroppableData; overData: TypesafeDroppableData;
@ -21,7 +25,7 @@ export const addImageDroppedListener = () => {
startAppListening({ startAppListening({
actionCreator: dndDropped, actionCreator: dndDropped,
effect: async (action, { dispatch }) => { effect: async (action, { dispatch }) => {
const log = logger('images'); const log = logger('dnd');
const { activeData, overData } = action.payload; const { activeData, overData } = action.payload;
if (activeData.payloadType === 'IMAGE_DTO') { if (activeData.payloadType === 'IMAGE_DTO') {
@ -31,10 +35,28 @@ export const addImageDroppedListener = () => {
{ activeData, overData }, { activeData, overData },
`Images (${activeData.payload.imageDTOs.length}) dropped` `Images (${activeData.payload.imageDTOs.length}) dropped`
); );
} else if (activeData.payloadType === 'NODE_FIELD') {
log.debug(
{ activeData: parseify(activeData), overData: parseify(overData) },
'Node field dropped'
);
} else { } else {
log.debug({ activeData, overData }, `Unknown payload dropped`); log.debug({ activeData, overData }, `Unknown payload dropped`);
} }
if (
overData.actionType === 'ADD_FIELD_TO_LINEAR' &&
activeData.payloadType === 'NODE_FIELD'
) {
const { nodeId, field } = activeData.payload;
dispatch(
workflowExposedFieldAdded({
nodeId,
fieldName: field.name,
})
);
}
/** /**
* Image dropped on current image * Image dropped on current image
*/ */
@ -99,7 +121,7 @@ export const addImageDroppedListener = () => {
) { ) {
const { fieldName, nodeId } = overData.context; const { fieldName, nodeId } = overData.context;
dispatch( dispatch(
fieldValueChanged({ fieldImageValueChanged({
nodeId, nodeId,
fieldName, fieldName,
value: activeData.payload.imageDTO, value: activeData.payload.imageDTO,

View File

@ -2,7 +2,7 @@ import { UseToastOptions } from '@chakra-ui/react';
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice'; import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { omit } from 'lodash-es'; import { omit } from 'lodash-es';
@ -111,7 +111,9 @@ export const addImageUploadedFulfilledListener = () => {
if (postUploadAction?.type === 'SET_NODES_IMAGE') { if (postUploadAction?.type === 'SET_NODES_IMAGE') {
const { nodeId, fieldName } = postUploadAction; const { nodeId, fieldName } = postUploadAction;
dispatch(fieldValueChanged({ nodeId, fieldName, value: imageDTO })); dispatch(
fieldImageValueChanged({ nodeId, fieldName, value: imageDTO })
);
dispatch( dispatch(
addToast({ addToast({
...DEFAULT_UPLOADED_TOAST, ...DEFAULT_UPLOADED_TOAST,

View File

@ -0,0 +1,30 @@
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..';
import { selectionChanged } from '../../../../../features/gallery/store/gallerySlice';
import { ImageDTO } from '../../../../../services/api/types';
/**
 * After a `starImages` mutation succeeds, syncs the gallery selection:
 * any selected image the server reports as updated is marked `starred: true`
 * locally; all other selection entries are kept unchanged.
 */
export const addImagesStarredListener = () => {
  startAppListening({
    matcher: imagesApi.endpoints.starImages.matchFulfilled,
    effect: async (action, { dispatch, getState }) => {
      const { updated_image_names: starredImages } = action.payload;

      const state = getState();
      const { selection } = state.gallery;

      // Rebuild the selection array, flipping `starred` on for updated images.
      const updatedSelection: ImageDTO[] = [];

      selection.forEach((selectedImageDTO) => {
        if (starredImages.includes(selectedImageDTO.image_name)) {
          updatedSelection.push({
            ...selectedImageDTO,
            starred: true,
          });
        } else {
          updatedSelection.push(selectedImageDTO);
        }
      });
      dispatch(selectionChanged(updatedSelection));
    },
  });
};

View File

@ -0,0 +1,30 @@
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..';
import { selectionChanged } from '../../../../../features/gallery/store/gallerySlice';
import { ImageDTO } from '../../../../../services/api/types';
/**
 * After an `unstarImages` mutation succeeds, syncs the gallery selection:
 * any selected image the server reports as updated is marked `starred: false`
 * locally; all other selection entries are kept unchanged.
 */
export const addImagesUnstarredListener = () => {
  startAppListening({
    matcher: imagesApi.endpoints.unstarImages.matchFulfilled,
    effect: async (action, { dispatch, getState }) => {
      const { updated_image_names: unstarredImages } = action.payload;

      const state = getState();
      const { selection } = state.gallery;

      // Rebuild the selection array, flipping `starred` off for updated images.
      const updatedSelection: ImageDTO[] = [];

      selection.forEach((selectedImageDTO) => {
        if (unstarredImages.includes(selectedImageDTO.image_name)) {
          updatedSelection.push({
            ...selectedImageDTO,
            starred: false,
          });
        } else {
          updatedSelection.push(selectedImageDTO);
        }
      });
      dispatch(selectionChanged(updatedSelection));
    },
  });
};

View File

@ -15,12 +15,21 @@ import {
setShouldUseSDXLRefiner, setShouldUseSDXLRefiner,
} from 'features/sdxl/store/sdxlSlice'; } from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es'; import { forEach, some } from 'lodash-es';
import { modelsApi, vaeModelsAdapter } from 'services/api/endpoints/models'; import {
mainModelsAdapter,
modelsApi,
vaeModelsAdapter,
} from 'services/api/endpoints/models';
import { TypeGuardFor } from 'services/api/types';
import { startAppListening } from '..'; import { startAppListening } from '..';
export const addModelsLoadedListener = () => { export const addModelsLoadedListener = () => {
startAppListening({ startAppListening({
predicate: (state, action) => predicate: (
action
): action is TypeGuardFor<
typeof modelsApi.endpoints.getMainModels.matchFulfilled
> =>
modelsApi.endpoints.getMainModels.matchFulfilled(action) && modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
!action.meta.arg.originalArgs.includes('sdxl-refiner'), !action.meta.arg.originalArgs.includes('sdxl-refiner'),
effect: async (action, { getState, dispatch }) => { effect: async (action, { getState, dispatch }) => {
@ -32,29 +41,28 @@ export const addModelsLoadedListener = () => {
); );
const currentModel = getState().generation.model; const currentModel = getState().generation.model;
const models = mainModelsAdapter.getSelectors().selectAll(action.payload);
const isCurrentModelAvailable = some( if (models.length === 0) {
action.payload.entities,
(m) =>
m?.model_name === currentModel?.model_name &&
m?.base_model === currentModel?.base_model &&
m?.model_type === currentModel?.model_type
);
if (isCurrentModelAvailable) {
return;
}
const firstModelId = action.payload.ids[0];
const firstModel = action.payload.entities[firstModelId];
if (!firstModel) {
// No models loaded at all // No models loaded at all
dispatch(modelChanged(null)); dispatch(modelChanged(null));
return; return;
} }
const result = zMainOrOnnxModel.safeParse(firstModel); const isCurrentModelAvailable = currentModel
? models.some(
(m) =>
m.model_name === currentModel.model_name &&
m.base_model === currentModel.base_model &&
m.model_type === currentModel.model_type
)
: false;
if (isCurrentModelAvailable) {
return;
}
const result = zMainOrOnnxModel.safeParse(models[0]);
if (!result.success) { if (!result.success) {
log.error( log.error(
@ -68,7 +76,11 @@ export const addModelsLoadedListener = () => {
}, },
}); });
startAppListening({ startAppListening({
predicate: (state, action) => predicate: (
action
): action is TypeGuardFor<
typeof modelsApi.endpoints.getMainModels.matchFulfilled
> =>
modelsApi.endpoints.getMainModels.matchFulfilled(action) && modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
action.meta.arg.originalArgs.includes('sdxl-refiner'), action.meta.arg.originalArgs.includes('sdxl-refiner'),
effect: async (action, { getState, dispatch }) => { effect: async (action, { getState, dispatch }) => {
@ -80,30 +92,29 @@ export const addModelsLoadedListener = () => {
); );
const currentModel = getState().sdxl.refinerModel; const currentModel = getState().sdxl.refinerModel;
const models = mainModelsAdapter.getSelectors().selectAll(action.payload);
const isCurrentModelAvailable = some( if (models.length === 0) {
action.payload.entities,
(m) =>
m?.model_name === currentModel?.model_name &&
m?.base_model === currentModel?.base_model &&
m?.model_type === currentModel?.model_type
);
if (isCurrentModelAvailable) {
return;
}
const firstModelId = action.payload.ids[0];
const firstModel = action.payload.entities[firstModelId];
if (!firstModel) {
// No models loaded at all // No models loaded at all
dispatch(refinerModelChanged(null)); dispatch(refinerModelChanged(null));
dispatch(setShouldUseSDXLRefiner(false)); dispatch(setShouldUseSDXLRefiner(false));
return; return;
} }
const result = zSDXLRefinerModel.safeParse(firstModel); const isCurrentModelAvailable = currentModel
? models.some(
(m) =>
m.model_name === currentModel.model_name &&
m.base_model === currentModel.base_model &&
m.model_type === currentModel.model_type
)
: false;
if (isCurrentModelAvailable) {
return;
}
const result = zSDXLRefinerModel.safeParse(models[0]);
if (!result.success) { if (!result.success) {
log.error( log.error(

View File

@ -13,7 +13,7 @@ export const addReceivedOpenAPISchemaListener = () => {
const log = logger('system'); const log = logger('system');
const schemaJSON = action.payload; const schemaJSON = action.payload;
log.debug({ schemaJSON }, 'Dereferenced OpenAPI schema'); log.debug({ schemaJSON }, 'Received OpenAPI schema');
const nodeTemplates = parseSchema(schemaJSON); const nodeTemplates = parseSchema(schemaJSON);
@ -28,9 +28,12 @@ export const addReceivedOpenAPISchemaListener = () => {
startAppListening({ startAppListening({
actionCreator: receivedOpenAPISchema.rejected, actionCreator: receivedOpenAPISchema.rejected,
effect: () => { effect: (action) => {
const log = logger('system'); const log = logger('system');
log.error('Problem dereferencing OpenAPI Schema'); log.error(
{ error: parseify(action.error) },
'Problem retrieving OpenAPI Schema'
);
}, },
}); });
}; };

View File

@ -19,7 +19,7 @@ import {
} from 'services/events/actions'; } from 'services/events/actions';
import { startAppListening } from '../..'; import { startAppListening } from '../..';
const nodeDenylist = ['dataURL_image']; const nodeDenylist = ['load_image'];
export const addInvocationCompleteEventListener = () => { export const addInvocationCompleteEventListener = () => {
startAppListening({ startAppListening({

View File

@ -15,7 +15,7 @@ export const addUserInvokedNodesListener = () => {
const log = logger('session'); const log = logger('session');
const state = getState(); const state = getState();
const graph = buildNodesGraph(state); const graph = buildNodesGraph(state.nodes);
dispatch(nodesGraphBuilt(graph)); dispatch(nodesGraphBuilt(graph));
log.debug({ graph: parseify(graph) }, 'Nodes graph built'); log.debug({ graph: parseify(graph) }, 'Nodes graph built');

View File

@ -1,86 +1,7 @@
import { import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
// CONTROLNET_MODELS,
CONTROLNET_PROCESSORS,
} from 'features/controlNet/store/constants';
import { InvokeTabName } from 'features/ui/store/tabMap'; import { InvokeTabName } from 'features/ui/store/tabMap';
import { O } from 'ts-toolbelt'; import { O } from 'ts-toolbelt';
// These are old types from the model management UI
// export type ModelStatus = 'active' | 'cached' | 'not loaded';
// export type Model = {
// status: ModelStatus;
// description: string;
// weights: string;
// config?: string;
// vae?: string;
// width?: number;
// height?: number;
// default?: boolean;
// format?: string;
// };
// export type DiffusersModel = {
// status: ModelStatus;
// description: string;
// repo_id?: string;
// path?: string;
// vae?: {
// repo_id?: string;
// path?: string;
// };
// format?: string;
// default?: boolean;
// };
// export type ModelList = Record<string, Model & DiffusersModel>;
// export type FoundModel = {
// name: string;
// location: string;
// };
// export type InvokeModelConfigProps = {
// name: string | undefined;
// description: string | undefined;
// config: string | undefined;
// weights: string | undefined;
// vae: string | undefined;
// width: number | undefined;
// height: number | undefined;
// default: boolean | undefined;
// format: string | undefined;
// };
// export type InvokeDiffusersModelConfigProps = {
// name: string | undefined;
// description: string | undefined;
// repo_id: string | undefined;
// path: string | undefined;
// default: boolean | undefined;
// format: string | undefined;
// vae: {
// repo_id: string | undefined;
// path: string | undefined;
// };
// };
// export type InvokeModelConversionProps = {
// model_name: string;
// save_location: string;
// custom_location: string | null;
// };
// export type InvokeModelMergingProps = {
// models_to_merge: string[];
// alpha: number;
// interp: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference';
// force: boolean;
// merged_model_name: string;
// model_merge_save_path: string | null;
// };
/** /**
* A disable-able application feature * A disable-able application feature
*/ */

View File

@ -0,0 +1,126 @@
/**
* This is a copy-paste of https://github.com/lukasbach/chakra-ui-contextmenu with a small change.
*
* The reactflow background element somehow prevents the chakra `useOutsideClick()` hook from working.
* With a menu open, clicking on the reactflow background element doesn't close the menu.
*
* Reactflow does provide an `onPaneClick` to handle clicks on the background element, but it is not
* straightforward to programatically close the menu.
*
* As a (hopefully temporary) workaround, we will use a dirty hack:
* - create `globalContextMenuCloseTrigger: number` in `ui` slice
* - increment it in `onPaneClick`
* - `useEffect()` to close the menu when `globalContextMenuCloseTrigger` changes
*/
import {
Menu,
MenuButton,
MenuButtonProps,
MenuProps,
Portal,
PortalProps,
useEventListener,
} from '@chakra-ui/react';
import { useAppSelector } from 'app/store/storeHooks';
import * as React from 'react';
import {
MutableRefObject,
useCallback,
useEffect,
useRef,
useState,
} from 'react';
export interface IAIContextMenuProps<T extends HTMLElement> {
renderMenu: () => JSX.Element | null;
children: (ref: MutableRefObject<T | null>) => JSX.Element | null;
menuProps?: Omit<MenuProps, 'children'> & { children?: React.ReactNode };
portalProps?: Omit<PortalProps, 'children'> & { children?: React.ReactNode };
menuButtonProps?: MenuButtonProps;
}
export function IAIContextMenu<T extends HTMLElement = HTMLElement>(
props: IAIContextMenuProps<T>
) {
const [isOpen, setIsOpen] = useState(false);
const [isRendered, setIsRendered] = useState(false);
const [isDeferredOpen, setIsDeferredOpen] = useState(false);
const [position, setPosition] = useState<[number, number]>([0, 0]);
const targetRef = useRef<T>(null);
const globalContextMenuCloseTrigger = useAppSelector(
(state) => state.ui.globalContextMenuCloseTrigger
);
useEffect(() => {
if (isOpen) {
setTimeout(() => {
setIsRendered(true);
setTimeout(() => {
setIsDeferredOpen(true);
});
});
} else {
setIsDeferredOpen(false);
const timeout = setTimeout(() => {
setIsRendered(isOpen);
}, 1000);
return () => clearTimeout(timeout);
}
}, [isOpen]);
useEffect(() => {
setIsOpen(false);
setIsDeferredOpen(false);
setIsRendered(false);
}, [globalContextMenuCloseTrigger]);
useEventListener('contextmenu', (e) => {
if (
targetRef.current?.contains(e.target as HTMLElement) ||
e.target === targetRef.current
) {
e.preventDefault();
setIsOpen(true);
setPosition([e.pageX, e.pageY]);
} else {
setIsOpen(false);
}
});
const onCloseHandler = useCallback(() => {
props.menuProps?.onClose?.();
setIsOpen(false);
}, [props.menuProps]);
return (
<>
{props.children(targetRef)}
{isRendered && (
<Portal {...props.portalProps}>
<Menu
isOpen={isDeferredOpen}
gutter={0}
{...props.menuProps}
onClose={onCloseHandler}
>
<MenuButton
aria-hidden={true}
w={1}
h={1}
style={{
position: 'absolute',
left: position[0],
top: position[1],
cursor: 'default',
}}
{...props.menuButtonProps}
/>
{props.renderMenu()}
</Menu>
</Portal>
)}
</>
);
}

View File

@ -1,16 +1,11 @@
import { import {
ChakraProps, ChakraProps,
Flex, Flex,
FlexProps,
Icon, Icon,
Image, Image,
useColorMode, useColorMode,
useColorModeValue,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd';
import IAIIconButton from 'common/components/IAIIconButton';
import { import {
IAILoadingImageFallback, IAILoadingImageFallback,
IAINoContentFallback, IAINoContentFallback,
@ -21,27 +16,39 @@ import ImageContextMenu from 'features/gallery/components/ImageContextMenu/Image
import { import {
MouseEvent, MouseEvent,
ReactElement, ReactElement,
ReactNode,
SyntheticEvent, SyntheticEvent,
memo, memo,
useCallback, useCallback,
useState, useState,
} from 'react'; } from 'react';
import { FaImage, FaUndo, FaUpload } from 'react-icons/fa'; import { FaImage, FaUpload } from 'react-icons/fa';
import { ImageDTO, PostUploadAction } from 'services/api/types'; import { ImageDTO, PostUploadAction } from 'services/api/types';
import { mode } from 'theme/util/mode'; import { mode } from 'theme/util/mode';
import IAIDraggable from './IAIDraggable'; import IAIDraggable from './IAIDraggable';
import IAIDroppable from './IAIDroppable'; import IAIDroppable from './IAIDroppable';
import SelectionOverlay from './SelectionOverlay'; import SelectionOverlay from './SelectionOverlay';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'features/dnd/types';
type IAIDndImageProps = { const defaultUploadElement = (
<Icon
as={FaUpload}
sx={{
boxSize: 16,
}}
/>
);
const defaultNoContentFallback = <IAINoContentFallback icon={FaImage} />;
type IAIDndImageProps = FlexProps & {
imageDTO: ImageDTO | undefined; imageDTO: ImageDTO | undefined;
onError?: (event: SyntheticEvent<HTMLImageElement>) => void; onError?: (event: SyntheticEvent<HTMLImageElement>) => void;
onLoad?: (event: SyntheticEvent<HTMLImageElement>) => void; onLoad?: (event: SyntheticEvent<HTMLImageElement>) => void;
onClick?: (event: MouseEvent<HTMLDivElement>) => void; onClick?: (event: MouseEvent<HTMLDivElement>) => void;
onClickReset?: (event: MouseEvent<HTMLButtonElement>) => void;
withResetIcon?: boolean;
resetIcon?: ReactElement;
resetTooltip?: string;
withMetadataOverlay?: boolean; withMetadataOverlay?: boolean;
isDragDisabled?: boolean; isDragDisabled?: boolean;
isDropDisabled?: boolean; isDropDisabled?: boolean;
@ -52,21 +59,21 @@ type IAIDndImageProps = {
fitContainer?: boolean; fitContainer?: boolean;
droppableData?: TypesafeDroppableData; droppableData?: TypesafeDroppableData;
draggableData?: TypesafeDraggableData; draggableData?: TypesafeDraggableData;
dropLabel?: string; dropLabel?: ReactNode;
isSelected?: boolean; isSelected?: boolean;
thumbnail?: boolean; thumbnail?: boolean;
noContentFallback?: ReactElement; noContentFallback?: ReactElement;
useThumbailFallback?: boolean; useThumbailFallback?: boolean;
withHoverOverlay?: boolean; withHoverOverlay?: boolean;
children?: JSX.Element;
uploadElement?: ReactNode;
}; };
const IAIDndImage = (props: IAIDndImageProps) => { const IAIDndImage = (props: IAIDndImageProps) => {
const { const {
imageDTO, imageDTO,
onClickReset,
onError, onError,
onClick, onClick,
withResetIcon = false,
withMetadataOverlay = false, withMetadataOverlay = false,
isDropDisabled = false, isDropDisabled = false,
isDragDisabled = false, isDragDisabled = false,
@ -80,32 +87,37 @@ const IAIDndImage = (props: IAIDndImageProps) => {
dropLabel, dropLabel,
isSelected = false, isSelected = false,
thumbnail = false, thumbnail = false,
resetTooltip = 'Reset', noContentFallback = defaultNoContentFallback,
resetIcon = <FaUndo />, uploadElement = defaultUploadElement,
noContentFallback = <IAINoContentFallback icon={FaImage} />,
useThumbailFallback, useThumbailFallback,
withHoverOverlay = false, withHoverOverlay = false,
children,
onMouseOver,
onMouseOut,
} = props; } = props;
const { colorMode } = useColorMode(); const { colorMode } = useColorMode();
const [isHovered, setIsHovered] = useState(false); const [isHovered, setIsHovered] = useState(false);
const handleMouseOver = useCallback(() => { const handleMouseOver = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
if (onMouseOver) onMouseOver(e);
setIsHovered(true); setIsHovered(true);
}, []); },
const handleMouseOut = useCallback(() => { [onMouseOver]
);
const handleMouseOut = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
if (onMouseOut) onMouseOut(e);
setIsHovered(false); setIsHovered(false);
}, []); },
[onMouseOut]
);
const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({ const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({
postUploadAction, postUploadAction,
isDisabled: isUploadDisabled, isDisabled: isUploadDisabled,
}); });
const resetIconShadow = useColorModeValue(
`drop-shadow(0px 0px 0.1rem var(--invokeai-colors-base-600))`,
`drop-shadow(0px 0px 0.1rem var(--invokeai-colors-base-800))`
);
const uploadButtonStyles = isUploadDisabled const uploadButtonStyles = isUploadDisabled
? {} ? {}
: { : {
@ -157,11 +169,10 @@ const IAIDndImage = (props: IAIDndImageProps) => {
<IAILoadingImageFallback image={imageDTO} /> <IAILoadingImageFallback image={imageDTO} />
) )
} }
width={imageDTO.width}
height={imageDTO.height}
onError={onError} onError={onError}
draggable={false} draggable={false}
sx={{ sx={{
w: imageDTO.width,
objectFit: 'contain', objectFit: 'contain',
maxW: 'full', maxW: 'full',
maxH: 'full', maxH: 'full',
@ -196,12 +207,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
{...getUploadButtonProps()} {...getUploadButtonProps()}
> >
<input {...getUploadInputProps()} /> <input {...getUploadInputProps()} />
<Icon {uploadElement}
as={FaUpload}
sx={{
boxSize: 16,
}}
/>
</Flex> </Flex>
</> </>
)} )}
@ -213,6 +219,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
onClick={onClick} onClick={onClick}
/> />
)} )}
{children}
{!isDropDisabled && ( {!isDropDisabled && (
<IAIDroppable <IAIDroppable
data={droppableData} data={droppableData}
@ -220,30 +227,6 @@ const IAIDndImage = (props: IAIDndImageProps) => {
dropLabel={dropLabel} dropLabel={dropLabel}
/> />
)} )}
{onClickReset && withResetIcon && imageDTO && (
<IAIIconButton
onClick={onClickReset}
aria-label={resetTooltip}
tooltip={resetTooltip}
icon={resetIcon}
size="sm"
variant="link"
sx={{
position: 'absolute',
top: 1,
insetInlineEnd: 1,
p: 0,
minW: 0,
svg: {
transitionProperty: 'common',
transitionDuration: 'normal',
fill: 'base.100',
_hover: { fill: 'base.50' },
filter: resetIconShadow,
},
}}
/>
)}
</Flex> </Flex>
)} )}
</ImageContextMenu> </ImageContextMenu>

View File

@ -0,0 +1,46 @@
import { SystemStyleObject, useColorModeValue } from '@chakra-ui/react';
import { MouseEvent, ReactElement, memo } from 'react';
import IAIIconButton from './IAIIconButton';
type Props = {
onClick: (event: MouseEvent<HTMLButtonElement>) => void;
tooltip: string;
icon?: ReactElement;
styleOverrides?: SystemStyleObject;
};
const IAIDndImageIcon = (props: Props) => {
const { onClick, tooltip, icon, styleOverrides } = props;
const resetIconShadow = useColorModeValue(
`drop-shadow(0px 0px 0.1rem var(--invokeai-colors-base-600))`,
`drop-shadow(0px 0px 0.1rem var(--invokeai-colors-base-800))`
);
return (
<IAIIconButton
onClick={onClick}
aria-label={tooltip}
tooltip={tooltip}
icon={icon}
size="sm"
variant="link"
sx={{
position: 'absolute',
top: 1,
insetInlineEnd: 1,
p: 0,
minW: 0,
svg: {
transitionProperty: 'common',
transitionDuration: 'normal',
fill: 'base.100',
_hover: { fill: 'base.50' },
filter: resetIconShadow,
},
...styleOverrides,
}}
/>
);
};
export default memo(IAIDndImageIcon);

View File

@ -1,22 +1,19 @@
import { Box } from '@chakra-ui/react'; import { Box, BoxProps } from '@chakra-ui/react';
import { import { useDraggableTypesafe } from 'features/dnd/hooks/typesafeHooks';
TypesafeDraggableData, import { TypesafeDraggableData } from 'features/dnd/types';
useDraggable, import { memo, useRef } from 'react';
} from 'app/components/ImageDnd/typesafeDnd';
import { MouseEvent, memo, useRef } from 'react';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
type IAIDraggableProps = { type IAIDraggableProps = BoxProps & {
disabled?: boolean; disabled?: boolean;
data?: TypesafeDraggableData; data?: TypesafeDraggableData;
onClick?: (event: MouseEvent<HTMLDivElement>) => void;
}; };
const IAIDraggable = (props: IAIDraggableProps) => { const IAIDraggable = (props: IAIDraggableProps) => {
const { data, disabled, onClick } = props; const { data, disabled, ...rest } = props;
const dndId = useRef(uuidv4()); const dndId = useRef(uuidv4());
const { attributes, listeners, setNodeRef } = useDraggable({ const { attributes, listeners, setNodeRef } = useDraggableTypesafe({
id: dndId.current, id: dndId.current,
disabled, disabled,
data, data,
@ -24,7 +21,6 @@ const IAIDraggable = (props: IAIDraggableProps) => {
return ( return (
<Box <Box
onClick={onClick}
ref={setNodeRef} ref={setNodeRef}
position="absolute" position="absolute"
w="full" w="full"
@ -33,6 +29,7 @@ const IAIDraggable = (props: IAIDraggableProps) => {
insetInlineStart={0} insetInlineStart={0}
{...attributes} {...attributes}
{...listeners} {...listeners}
{...rest}
/> />
); );
}; };

View File

@ -1,9 +1,7 @@
import { Box } from '@chakra-ui/react'; import { Box } from '@chakra-ui/react';
import { import { useDroppableTypesafe } from 'features/dnd/hooks/typesafeHooks';
TypesafeDroppableData, import { TypesafeDroppableData } from 'features/dnd/types';
isValidDrop, import { isValidDrop } from 'features/dnd/util/isValidDrop';
useDroppable,
} from 'app/components/ImageDnd/typesafeDnd';
import { AnimatePresence } from 'framer-motion'; import { AnimatePresence } from 'framer-motion';
import { ReactNode, memo, useRef } from 'react'; import { ReactNode, memo, useRef } from 'react';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
@ -19,7 +17,7 @@ const IAIDroppable = (props: IAIDroppableProps) => {
const { dropLabel, data, disabled } = props; const { dropLabel, data, disabled } = props;
const dndId = useRef(uuidv4()); const dndId = useRef(uuidv4());
const { isOver, setNodeRef, active } = useDroppable({ const { isOver, setNodeRef, active } = useDroppableTypesafe({
id: dndId.current, id: dndId.current,
disabled, disabled,
data, data,

View File

@ -49,7 +49,7 @@ export const IAILoadingImageFallback = (props: Props) => {
type IAINoImageFallbackProps = { type IAINoImageFallbackProps = {
label?: string; label?: string;
icon?: As; icon?: As | null;
boxSize?: StyleProps['boxSize']; boxSize?: StyleProps['boxSize'];
sx?: ChakraProps['sx']; sx?: ChakraProps['sx'];
}; };
@ -76,7 +76,7 @@ export const IAINoContentFallback = (props: IAINoImageFallbackProps) => {
...props.sx, ...props.sx,
}} }}
> >
<Icon as={icon} boxSize={boxSize} opacity={0.7} /> {icon && <Icon as={icon} boxSize={boxSize} opacity={0.7} />}
{props.label && <Text textAlign="center">{props.label}</Text>} {props.label && <Text textAlign="center">{props.label}</Text>}
</Flex> </Flex>
); );

View File

@ -1,10 +1,13 @@
import { import {
Flex,
FormControl, FormControl,
FormControlProps, FormControlProps,
FormHelperText,
FormLabel, FormLabel,
FormLabelProps, FormLabelProps,
Switch, Switch,
SwitchProps, SwitchProps,
Text,
Tooltip, Tooltip,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { memo } from 'react'; import { memo } from 'react';
@ -15,6 +18,7 @@ export interface IAISwitchProps extends SwitchProps {
formControlProps?: FormControlProps; formControlProps?: FormControlProps;
formLabelProps?: FormLabelProps; formLabelProps?: FormLabelProps;
tooltip?: string; tooltip?: string;
helperText?: string;
} }
/** /**
@ -28,6 +32,7 @@ const IAISwitch = (props: IAISwitchProps) => {
formControlProps, formControlProps,
formLabelProps, formLabelProps,
tooltip, tooltip,
helperText,
...rest ...rest
} = props; } = props;
return ( return (
@ -35,10 +40,11 @@ const IAISwitch = (props: IAISwitchProps) => {
<FormControl <FormControl
isDisabled={isDisabled} isDisabled={isDisabled}
width={width} width={width}
display="flex"
alignItems="center" alignItems="center"
{...formControlProps} {...formControlProps}
> >
<Flex sx={{ flexDir: 'column', w: 'full' }}>
<Flex sx={{ alignItems: 'center', w: 'full' }}>
{label && ( {label && (
<FormLabel <FormLabel
my={1} my={1}
@ -54,6 +60,13 @@ const IAISwitch = (props: IAISwitchProps) => {
</FormLabel> </FormLabel>
)} )}
<Switch {...rest} /> <Switch {...rest} />
</Flex>
{helperText && (
<FormHelperText>
<Text variant="subtext">{helperText}</Text>
</FormHelperText>
)}
</Flex>
</FormControl> </FormControl>
</Tooltip> </Tooltip>
); );

View File

@ -40,6 +40,44 @@ export const useChakraThemeTokens = () => {
accent850, accent850,
accent900, accent900,
accent950, accent950,
baseAlpha50,
baseAlpha100,
baseAlpha150,
baseAlpha200,
baseAlpha250,
baseAlpha300,
baseAlpha350,
baseAlpha400,
baseAlpha450,
baseAlpha500,
baseAlpha550,
baseAlpha600,
baseAlpha650,
baseAlpha700,
baseAlpha750,
baseAlpha800,
baseAlpha850,
baseAlpha900,
baseAlpha950,
accentAlpha50,
accentAlpha100,
accentAlpha150,
accentAlpha200,
accentAlpha250,
accentAlpha300,
accentAlpha350,
accentAlpha400,
accentAlpha450,
accentAlpha500,
accentAlpha550,
accentAlpha600,
accentAlpha650,
accentAlpha700,
accentAlpha750,
accentAlpha800,
accentAlpha850,
accentAlpha900,
accentAlpha950,
] = useToken('colors', [ ] = useToken('colors', [
'base.50', 'base.50',
'base.100', 'base.100',
@ -79,6 +117,44 @@ export const useChakraThemeTokens = () => {
'accent.850', 'accent.850',
'accent.900', 'accent.900',
'accent.950', 'accent.950',
'baseAlpha.50',
'baseAlpha.100',
'baseAlpha.150',
'baseAlpha.200',
'baseAlpha.250',
'baseAlpha.300',
'baseAlpha.350',
'baseAlpha.400',
'baseAlpha.450',
'baseAlpha.500',
'baseAlpha.550',
'baseAlpha.600',
'baseAlpha.650',
'baseAlpha.700',
'baseAlpha.750',
'baseAlpha.800',
'baseAlpha.850',
'baseAlpha.900',
'baseAlpha.950',
'accentAlpha.50',
'accentAlpha.100',
'accentAlpha.150',
'accentAlpha.200',
'accentAlpha.250',
'accentAlpha.300',
'accentAlpha.350',
'accentAlpha.400',
'accentAlpha.450',
'accentAlpha.500',
'accentAlpha.550',
'accentAlpha.600',
'accentAlpha.650',
'accentAlpha.700',
'accentAlpha.750',
'accentAlpha.800',
'accentAlpha.850',
'accentAlpha.900',
'accentAlpha.950',
]); ]);
return { return {
@ -120,5 +196,43 @@ export const useChakraThemeTokens = () => {
accent850, accent850,
accent900, accent900,
accent950, accent950,
baseAlpha50,
baseAlpha100,
baseAlpha150,
baseAlpha200,
baseAlpha250,
baseAlpha300,
baseAlpha350,
baseAlpha400,
baseAlpha450,
baseAlpha500,
baseAlpha550,
baseAlpha600,
baseAlpha650,
baseAlpha700,
baseAlpha750,
baseAlpha800,
baseAlpha850,
baseAlpha900,
baseAlpha950,
accentAlpha50,
accentAlpha100,
accentAlpha150,
accentAlpha200,
accentAlpha250,
accentAlpha300,
accentAlpha350,
accentAlpha400,
accentAlpha450,
accentAlpha500,
accentAlpha550,
accentAlpha600,
accentAlpha650,
accentAlpha700,
accentAlpha750,
accentAlpha800,
accentAlpha850,
accentAlpha900,
accentAlpha950,
}; };
}; };

View File

@ -1,4 +1,10 @@
/** /**
* Serialize an object to JSON and back to a new object * Serialize an object to JSON and back to a new object
*/ */
export const parseify = (obj: unknown) => JSON.parse(JSON.stringify(obj)); export const parseify = (obj: unknown) => {
try {
return JSON.parse(JSON.stringify(obj));
} catch {
return 'Error parsing object';
}
};

View File

@ -2,10 +2,11 @@ import { ButtonGroup, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import IAIColorPicker from 'common/components/IAIColorPicker'; import IAIColorPicker from 'common/components/IAIColorPicker';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import IAIPopover from 'common/components/IAIPopover'; import IAIPopover from 'common/components/IAIPopover';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import { canvasMaskSavedToGallery } from 'features/canvas/store/actions';
import { import {
canvasSelector, canvasSelector,
isStagingSelector, isStagingSelector,
@ -22,7 +23,7 @@ import { isEqual } from 'lodash-es';
import { useHotkeys } from 'react-hotkeys-hook'; import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { FaMask, FaTrash } from 'react-icons/fa'; import { FaMask, FaSave, FaTrash } from 'react-icons/fa';
export const selector = createSelector( export const selector = createSelector(
[canvasSelector, isStagingSelector], [canvasSelector, isStagingSelector],
@ -102,6 +103,10 @@ const IAICanvasMaskOptions = () => {
const handleToggleEnableMask = () => const handleToggleEnableMask = () =>
dispatch(setIsMaskEnabled(!isMaskEnabled)); dispatch(setIsMaskEnabled(!isMaskEnabled));
const handleSaveMask = async () => {
dispatch(canvasMaskSavedToGallery());
};
return ( return (
<IAIPopover <IAIPopover
triggerComponent={ triggerComponent={
@ -134,6 +139,9 @@ const IAICanvasMaskOptions = () => {
pickerColor={maskColor} pickerColor={maskColor}
onChange={(newColor) => dispatch(setMaskColor(newColor))} onChange={(newColor) => dispatch(setMaskColor(newColor))}
/> />
<IAIButton size="sm" leftIcon={<FaSave />} onClick={handleSaveMask}>
Save Mask
</IAIButton>
<IAIButton size="sm" leftIcon={<FaTrash />} onClick={handleClearMask}> <IAIButton size="sm" leftIcon={<FaTrash />} onClick={handleClearMask}>
{t('unifiedCanvas.clearMask')} (Shift+C) {t('unifiedCanvas.clearMask')} (Shift+C)
</IAIButton> </IAIButton>

View File

@ -3,6 +3,10 @@ import { ImageDTO } from 'services/api/types';
export const canvasSavedToGallery = createAction('canvas/canvasSavedToGallery'); export const canvasSavedToGallery = createAction('canvas/canvasSavedToGallery');
export const canvasMaskSavedToGallery = createAction(
'canvas/canvasMaskSavedToGallery'
);
export const canvasCopiedToClipboard = createAction( export const canvasCopiedToClipboard = createAction(
'canvas/canvasCopiedToClipboard' 'canvas/canvasCopiedToClipboard'
); );

View File

@ -4,14 +4,16 @@ import { skipToken } from '@reduxjs/toolkit/dist/query';
import { import {
TypesafeDraggableData, TypesafeDraggableData,
TypesafeDroppableData, TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd'; } from 'features/dnd/types';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage'; import IAIDndImage from 'common/components/IAIDndImage';
import { memo, useCallback, useMemo, useState } from 'react'; import { memo, useCallback, useMemo, useState } from 'react';
import { FaUndo } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/types'; import { PostUploadAction } from 'services/api/types';
import IAIDndImageIcon from '../../../common/components/IAIDndImageIcon';
import { import {
ControlNetConfig, ControlNetConfig,
controlNetImageChanged, controlNetImageChanged,
@ -119,11 +121,15 @@ const ControlNetImagePreview = (props: Props) => {
droppableData={droppableData} droppableData={droppableData}
imageDTO={controlImage} imageDTO={controlImage}
isDropDisabled={shouldShowProcessedImage || !isEnabled} isDropDisabled={shouldShowProcessedImage || !isEnabled}
onClickReset={handleResetControlImage}
postUploadAction={postUploadAction} postUploadAction={postUploadAction}
resetTooltip="Reset Control Image" >
withResetIcon={Boolean(controlImage)} <IAIDndImageIcon
onClick={handleResetControlImage}
icon={controlImage ? <FaUndo /> : undefined}
tooltip="Reset Control Image"
/> />
</IAIDndImage>
<Box <Box
sx={{ sx={{
position: 'absolute', position: 'absolute',
@ -143,10 +149,13 @@ const ControlNetImagePreview = (props: Props) => {
imageDTO={processedControlImage} imageDTO={processedControlImage}
isUploadDisabled={true} isUploadDisabled={true}
isDropDisabled={!isEnabled} isDropDisabled={!isEnabled}
onClickReset={handleResetControlImage} >
resetTooltip="Reset Control Image" <IAIDndImageIcon
withResetIcon={Boolean(controlImage)} onClick={handleResetControlImage}
icon={controlImage ? <FaUndo /> : undefined}
tooltip="Reset Control Image"
/> />
</IAIDndImage>
</Box> </Box>
{pendingControlImages.includes(controlNetId) && ( {pendingControlImages.includes(controlNetId) && (
<Flex <Flex

View File

@ -138,7 +138,7 @@ export type RequiredZoeDepthImageProcessorInvocation = O.Required<
/** /**
* Any ControlNet Processor node, with its parameters flagged as required * Any ControlNet Processor node, with its parameters flagged as required
*/ */
export type RequiredControlNetProcessorNode = export type RequiredControlNetProcessorNode = O.Required<
| RequiredCannyImageProcessorInvocation | RequiredCannyImageProcessorInvocation
| RequiredContentShuffleImageProcessorInvocation | RequiredContentShuffleImageProcessorInvocation
| RequiredHedImageProcessorInvocation | RequiredHedImageProcessorInvocation
@ -150,7 +150,9 @@ export type RequiredControlNetProcessorNode =
| RequiredNormalbaeImageProcessorInvocation | RequiredNormalbaeImageProcessorInvocation
| RequiredOpenposeImageProcessorInvocation | RequiredOpenposeImageProcessorInvocation
| RequiredPidiImageProcessorInvocation | RequiredPidiImageProcessorInvocation
| RequiredZoeDepthImageProcessorInvocation; | RequiredZoeDepthImageProcessorInvocation,
'id'
>;
/** /**
* Type guard for CannyImageProcessorInvocation * Type guard for CannyImageProcessorInvocation

View File

@ -3,6 +3,7 @@ import { RootState } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { some } from 'lodash-es'; import { some } from 'lodash-es';
import { ImageUsage } from './types'; import { ImageUsage } from './types';
import { isInvocationNode } from 'features/nodes/types/types';
export const getImageUsage = (state: RootState, image_name: string) => { export const getImageUsage = (state: RootState, image_name: string) => {
const { generation, canvas, nodes, controlNet } = state; const { generation, canvas, nodes, controlNet } = state;
@ -12,11 +13,11 @@ export const getImageUsage = (state: RootState, image_name: string) => {
(obj) => obj.kind === 'image' && obj.imageName === image_name (obj) => obj.kind === 'image' && obj.imageName === image_name
); );
const isNodesImage = nodes.nodes.some((node) => { const isNodesImage = nodes.nodes.filter(isInvocationNode).some((node) => {
return some( return some(
node.data.inputs, node.data.inputs,
(input) => (input) =>
input.type === 'image' && input.value?.image_name === image_name input.type === 'ImageField' && input.value?.image_name === image_name
); );
}); });

View File

@ -6,23 +6,18 @@ import {
useSensor, useSensor,
useSensors, useSensors,
} from '@dnd-kit/core'; } from '@dnd-kit/core';
import { snapCenterToCursor } from '@dnd-kit/modifiers'; import { logger } from 'app/logging/logger';
import { dndDropped } from 'app/store/middleware/listenerMiddleware/listeners/imageDropped'; import { dndDropped } from 'app/store/middleware/listenerMiddleware/listeners/imageDropped';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { parseify } from 'common/util/serialize';
import { AnimatePresence, motion } from 'framer-motion'; import { AnimatePresence, motion } from 'framer-motion';
import { PropsWithChildren, memo, useCallback, useState } from 'react'; import { PropsWithChildren, memo, useCallback, useState } from 'react';
import { useScaledModifer } from '../hooks/useScaledCenteredModifer';
import { DragEndEvent, DragStartEvent, TypesafeDraggableData } from '../types';
import { DndContextTypesafe } from './DndContextTypesafe';
import DragPreview from './DragPreview'; import DragPreview from './DragPreview';
import {
DndContext,
DragEndEvent,
DragStartEvent,
TypesafeDraggableData,
} from './typesafeDnd';
import { logger } from 'app/logging/logger';
type ImageDndContextProps = PropsWithChildren; const AppDndContext = (props: PropsWithChildren) => {
const ImageDndContext = (props: ImageDndContextProps) => {
const [activeDragData, setActiveDragData] = const [activeDragData, setActiveDragData] =
useState<TypesafeDraggableData | null>(null); useState<TypesafeDraggableData | null>(null);
const log = logger('images'); const log = logger('images');
@ -31,7 +26,10 @@ const ImageDndContext = (props: ImageDndContextProps) => {
const handleDragStart = useCallback( const handleDragStart = useCallback(
(event: DragStartEvent) => { (event: DragStartEvent) => {
log.trace({ dragData: event.active.data.current }, 'Drag started'); log.trace(
{ dragData: parseify(event.active.data.current) },
'Drag started'
);
const activeData = event.active.data.current; const activeData = event.active.data.current;
if (!activeData) { if (!activeData) {
return; return;
@ -43,7 +41,10 @@ const ImageDndContext = (props: ImageDndContextProps) => {
const handleDragEnd = useCallback( const handleDragEnd = useCallback(
(event: DragEndEvent) => { (event: DragEndEvent) => {
log.trace({ dragData: event.active.data.current }, 'Drag ended'); log.trace(
{ dragData: parseify(event.active.data.current) },
'Drag ended'
);
const overData = event.over?.data.current; const overData = event.over?.data.current;
if (!activeDragData || !overData) { if (!activeDragData || !overData) {
return; return;
@ -69,15 +70,29 @@ const ImageDndContext = (props: ImageDndContextProps) => {
const sensors = useSensors(mouseSensor, touchSensor); const sensors = useSensors(mouseSensor, touchSensor);
const scaledModifier = useScaledModifer();
return ( return (
<DndContext <DndContextTypesafe
onDragStart={handleDragStart} onDragStart={handleDragStart}
onDragEnd={handleDragEnd} onDragEnd={handleDragEnd}
sensors={sensors} sensors={sensors}
collisionDetection={pointerWithin} collisionDetection={pointerWithin}
autoScroll={false}
> >
{props.children} {props.children}
<DragOverlay dropAnimation={null} modifiers={[snapCenterToCursor]}> <DragOverlay
dropAnimation={null}
modifiers={[scaledModifier]}
style={{
width: 'min-content',
height: 'min-content',
cursor: 'none',
userSelect: 'none',
// expand overlay to prevent cursor from going outside it and displaying
padding: '10rem',
}}
>
<AnimatePresence> <AnimatePresence>
{activeDragData && ( {activeDragData && (
<motion.div <motion.div
@ -98,8 +113,8 @@ const ImageDndContext = (props: ImageDndContextProps) => {
)} )}
</AnimatePresence> </AnimatePresence>
</DragOverlay> </DragOverlay>
</DndContext> </DndContextTypesafe>
); );
}; };
export default memo(ImageDndContext); export default memo(AppDndContext);

View File

@ -0,0 +1,6 @@
import { DndContext } from '@dnd-kit/core';
import { DndContextTypesafeProps } from '../types';
export function DndContextTypesafe(props: DndContextTypesafeProps) {
return <DndContext {...props} />;
}

View File

@ -1,6 +1,6 @@
import { Box, ChakraProps, Flex, Heading, Image } from '@chakra-ui/react'; import { Box, ChakraProps, Flex, Heading, Image, Text } from '@chakra-ui/react';
import { memo } from 'react'; import { memo } from 'react';
import { TypesafeDraggableData } from './typesafeDnd'; import { TypesafeDraggableData } from '../types';
type OverlayDragImageProps = { type OverlayDragImageProps = {
dragData: TypesafeDraggableData | null; dragData: TypesafeDraggableData | null;
@ -30,19 +30,38 @@ const DragPreview = (props: OverlayDragImageProps) => {
return null; return null;
} }
if (props.dragData.payloadType === 'NODE_FIELD') {
const { field, fieldTemplate } = props.dragData.payload;
return (
<Box
sx={{
position: 'relative',
p: 2,
px: 3,
opacity: 0.7,
bg: 'base.300',
borderRadius: 'base',
boxShadow: 'dark-lg',
whiteSpace: 'nowrap',
fontSize: 'sm',
}}
>
<Text>{field.label || fieldTemplate.title}</Text>
</Box>
);
}
if (props.dragData.payloadType === 'IMAGE_DTO') { if (props.dragData.payloadType === 'IMAGE_DTO') {
const { thumbnail_url, width, height } = props.dragData.payload.imageDTO; const { thumbnail_url, width, height } = props.dragData.payload.imageDTO;
return ( return (
<Box <Box
sx={{ sx={{
position: 'relative', position: 'relative',
width: '100%', width: 'full',
height: '100%', height: 'full',
display: 'flex', display: 'flex',
alignItems: 'center', alignItems: 'center',
justifyContent: 'center', justifyContent: 'center',
userSelect: 'none',
cursor: 'none',
}} }}
> >
<Image <Image
@ -62,8 +81,6 @@ const DragPreview = (props: OverlayDragImageProps) => {
return ( return (
<Flex <Flex
sx={{ sx={{
cursor: 'none',
userSelect: 'none',
position: 'relative', position: 'relative',
alignItems: 'center', alignItems: 'center',
justifyContent: 'center', justifyContent: 'center',

View File

@ -0,0 +1,15 @@
import { useDraggable, useDroppable } from '@dnd-kit/core';
import {
UseDraggableTypesafeArguments,
UseDraggableTypesafeReturnValue,
UseDroppableTypesafeArguments,
UseDroppableTypesafeReturnValue,
} from '../types';
export function useDroppableTypesafe(props: UseDroppableTypesafeArguments) {
return useDroppable(props) as UseDroppableTypesafeReturnValue;
}
export function useDraggableTypesafe(props: UseDraggableTypesafeArguments) {
return useDraggable(props) as UseDraggableTypesafeReturnValue;
}

View File

@ -0,0 +1,51 @@
import type { Modifier } from '@dnd-kit/core';
import { getEventCoordinates } from '@dnd-kit/utilities';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useCallback } from 'react';
const selectZoom = createSelector(
[stateSelector, activeTabNameSelector],
({ nodes }, activeTabName) =>
activeTabName === 'nodes' ? nodes.viewport.zoom : 1
);
/**
* Applies scaling to the drag transform (if on node editor tab) and centers it on cursor.
*/
export const useScaledModifer = () => {
const zoom = useAppSelector(selectZoom);
const modifier: Modifier = useCallback(
({ activatorEvent, draggingNodeRect, transform }) => {
if (draggingNodeRect && activatorEvent) {
const activatorCoordinates = getEventCoordinates(activatorEvent);
if (!activatorCoordinates) {
return transform;
}
const offsetX = activatorCoordinates.x - draggingNodeRect.left;
const offsetY = activatorCoordinates.y - draggingNodeRect.top;
const x = transform.x + offsetX - draggingNodeRect.width / 2;
const y = transform.y + offsetY - draggingNodeRect.height / 2;
const scaleX = transform.scaleX * zoom;
const scaleY = transform.scaleY * zoom;
return {
x,
y,
scaleX,
scaleY,
};
}
return transform;
},
[zoom]
);
return modifier;
};

View File

@ -3,7 +3,6 @@ import {
Active, Active,
Collision, Collision,
DndContextProps, DndContextProps,
DndContext as OriginalDndContext,
Over, Over,
Translate, Translate,
UseDraggableArguments, UseDraggableArguments,
@ -11,6 +10,10 @@ import {
useDraggable as useOriginalDraggable, useDraggable as useOriginalDraggable,
useDroppable as useOriginalDroppable, useDroppable as useOriginalDroppable,
} from '@dnd-kit/core'; } from '@dnd-kit/core';
import {
InputFieldTemplate,
InputFieldValue,
} from 'features/nodes/types/types';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
type BaseDropData = { type BaseDropData = {
@ -62,6 +65,10 @@ export type RemoveFromBoardDropData = BaseDropData & {
actionType: 'REMOVE_FROM_BOARD'; actionType: 'REMOVE_FROM_BOARD';
}; };
export type AddFieldToLinearViewDropData = BaseDropData & {
actionType: 'ADD_FIELD_TO_LINEAR';
};
export type TypesafeDroppableData = export type TypesafeDroppableData =
| CurrentImageDropData | CurrentImageDropData
| InitialImageDropData | InitialImageDropData
@ -71,12 +78,22 @@ export type TypesafeDroppableData =
| AddToBatchDropData | AddToBatchDropData
| NodesMultiImageDropData | NodesMultiImageDropData
| AddToBoardDropData | AddToBoardDropData
| RemoveFromBoardDropData; | RemoveFromBoardDropData
| AddFieldToLinearViewDropData;
type BaseDragData = { type BaseDragData = {
id: string; id: string;
}; };
export type NodeFieldDraggableData = BaseDragData & {
payloadType: 'NODE_FIELD';
payload: {
nodeId: string;
field: InputFieldValue;
fieldTemplate: InputFieldTemplate;
};
};
export type ImageDraggableData = BaseDragData & { export type ImageDraggableData = BaseDragData & {
payloadType: 'IMAGE_DTO'; payloadType: 'IMAGE_DTO';
payload: { imageDTO: ImageDTO }; payload: { imageDTO: ImageDTO };
@ -87,14 +104,17 @@ export type ImageDTOsDraggableData = BaseDragData & {
payload: { imageDTOs: ImageDTO[] }; payload: { imageDTOs: ImageDTO[] };
}; };
export type TypesafeDraggableData = ImageDraggableData | ImageDTOsDraggableData; export type TypesafeDraggableData =
| NodeFieldDraggableData
| ImageDraggableData
| ImageDTOsDraggableData;
interface UseDroppableTypesafeArguments export interface UseDroppableTypesafeArguments
extends Omit<UseDroppableArguments, 'data'> { extends Omit<UseDroppableArguments, 'data'> {
data?: TypesafeDroppableData; data?: TypesafeDroppableData;
} }
type UseDroppableTypesafeReturnValue = Omit< export type UseDroppableTypesafeReturnValue = Omit<
ReturnType<typeof useOriginalDroppable>, ReturnType<typeof useOriginalDroppable>,
'active' | 'over' 'active' | 'over'
> & { > & {
@ -102,16 +122,12 @@ type UseDroppableTypesafeReturnValue = Omit<
over: TypesafeOver | null; over: TypesafeOver | null;
}; };
export function useDroppable(props: UseDroppableTypesafeArguments) { export interface UseDraggableTypesafeArguments
return useOriginalDroppable(props) as UseDroppableTypesafeReturnValue;
}
interface UseDraggableTypesafeArguments
extends Omit<UseDraggableArguments, 'data'> { extends Omit<UseDraggableArguments, 'data'> {
data?: TypesafeDraggableData; data?: TypesafeDraggableData;
} }
type UseDraggableTypesafeReturnValue = Omit< export type UseDraggableTypesafeReturnValue = Omit<
ReturnType<typeof useOriginalDraggable>, ReturnType<typeof useOriginalDraggable>,
'active' | 'over' 'active' | 'over'
> & { > & {
@ -119,102 +135,14 @@ type UseDraggableTypesafeReturnValue = Omit<
over: TypesafeOver | null; over: TypesafeOver | null;
}; };
export function useDraggable(props: UseDraggableTypesafeArguments) { export interface TypesafeActive extends Omit<Active, 'data'> {
return useOriginalDraggable(props) as UseDraggableTypesafeReturnValue;
}
interface TypesafeActive extends Omit<Active, 'data'> {
data: React.MutableRefObject<TypesafeDraggableData | undefined>; data: React.MutableRefObject<TypesafeDraggableData | undefined>;
} }
interface TypesafeOver extends Omit<Over, 'data'> { export interface TypesafeOver extends Omit<Over, 'data'> {
data: React.MutableRefObject<TypesafeDroppableData | undefined>; data: React.MutableRefObject<TypesafeDroppableData | undefined>;
} }
export const isValidDrop = (
overData: TypesafeDroppableData | undefined,
active: TypesafeActive | null
) => {
if (!overData || !active?.data.current) {
return false;
}
const { actionType } = overData;
const { payloadType } = active.data.current;
if (overData.id === active.data.current.id) {
return false;
}
switch (actionType) {
case 'SET_CURRENT_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_INITIAL_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_CONTROLNET_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_CANVAS_INITIAL_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_NODES_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_MULTI_NODES_IMAGE':
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
case 'ADD_TO_BATCH':
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
case 'ADD_TO_BOARD': {
// If the board is the same, don't allow the drop
// Check the payload types
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
if (!isPayloadValid) {
return false;
}
// Check if the image's board is the board we are dragging onto
if (payloadType === 'IMAGE_DTO') {
const { imageDTO } = active.data.current.payload;
const currentBoard = imageDTO.board_id ?? 'none';
const destinationBoard = overData.context.boardId;
return currentBoard !== destinationBoard;
}
if (payloadType === 'IMAGE_DTOS') {
// TODO (multi-select)
return true;
}
return false;
}
case 'REMOVE_FROM_BOARD': {
// If the board is the same, don't allow the drop
// Check the payload types
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
if (!isPayloadValid) {
return false;
}
// Check if the image's board is the board we are dragging onto
if (payloadType === 'IMAGE_DTO') {
const { imageDTO } = active.data.current.payload;
const currentBoard = imageDTO.board_id;
return currentBoard !== 'none';
}
if (payloadType === 'IMAGE_DTOS') {
// TODO (multi-select)
return true;
}
return false;
}
default:
return false;
}
};
interface DragEvent { interface DragEvent {
activatorEvent: Event; activatorEvent: Event;
active: TypesafeActive; active: TypesafeActive;
@ -240,6 +168,3 @@ export interface DndContextTypesafeProps
onDragEnd?(event: DragEndEvent): void; onDragEnd?(event: DragEndEvent): void;
onDragCancel?(event: DragCancelEvent): void; onDragCancel?(event: DragCancelEvent): void;
} }
export function DndContext(props: DndContextTypesafeProps) {
return <OriginalDndContext {...props} />;
}

View File

@ -0,0 +1,87 @@
import { TypesafeActive, TypesafeDroppableData } from '../types';
export const isValidDrop = (
overData: TypesafeDroppableData | undefined,
active: TypesafeActive | null
) => {
if (!overData || !active?.data.current) {
return false;
}
const { actionType } = overData;
const { payloadType } = active.data.current;
if (overData.id === active.data.current.id) {
return false;
}
switch (actionType) {
case 'ADD_FIELD_TO_LINEAR':
return payloadType === 'NODE_FIELD';
case 'SET_CURRENT_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_INITIAL_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_CONTROLNET_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_CANVAS_INITIAL_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_NODES_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_MULTI_NODES_IMAGE':
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
case 'ADD_TO_BATCH':
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
case 'ADD_TO_BOARD': {
// If the board is the same, don't allow the drop
// Check the payload types
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
if (!isPayloadValid) {
return false;
}
// Check if the image's board is the board we are dragging onto
if (payloadType === 'IMAGE_DTO') {
const { imageDTO } = active.data.current.payload;
const currentBoard = imageDTO.board_id ?? 'none';
const destinationBoard = overData.context.boardId;
return currentBoard !== destinationBoard;
}
if (payloadType === 'IMAGE_DTOS') {
// TODO (multi-select)
return true;
}
return false;
}
case 'REMOVE_FROM_BOARD': {
// If the board is the same, don't allow the drop
// Check the payload types
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
if (!isPayloadValid) {
return false;
}
// Check if the image's board is the board we are dragging onto
if (payloadType === 'IMAGE_DTO') {
const { imageDTO } = active.data.current.payload;
const currentBoard = imageDTO.board_id;
return currentBoard !== 'none';
}
if (payloadType === 'IMAGE_DTOS') {
// TODO (multi-select)
return true;
}
return false;
}
default:
return false;
}
};

View File

@ -11,7 +11,6 @@ import {
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/dist/query'; import { skipToken } from '@reduxjs/toolkit/dist/query';
import { AddToBoardDropData } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
@ -32,6 +31,7 @@ import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { BoardDTO } from 'services/api/types'; import { BoardDTO } from 'services/api/types';
import AutoAddIcon from '../AutoAddIcon'; import AutoAddIcon from '../AutoAddIcon';
import BoardContextMenu from '../BoardContextMenu'; import BoardContextMenu from '../BoardContextMenu';
import { AddToBoardDropData } from 'features/dnd/types';
interface GalleryBoardProps { interface GalleryBoardProps {
board: BoardDTO; board: BoardDTO;

View File

@ -1,7 +1,7 @@
import { As, Badge, Flex } from '@chakra-ui/react'; import { As, Badge, Flex } from '@chakra-ui/react';
import { TypesafeDroppableData } from 'app/components/ImageDnd/typesafeDnd';
import IAIDroppable from 'common/components/IAIDroppable'; import IAIDroppable from 'common/components/IAIDroppable';
import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { TypesafeDroppableData } from 'features/dnd/types';
import { BoardId } from 'features/gallery/store/types'; import { BoardId } from 'features/gallery/store/types';
import { ReactNode } from 'react'; import { ReactNode } from 'react';
import BoardContextMenu from '../BoardContextMenu'; import BoardContextMenu from '../BoardContextMenu';

View File

@ -1,15 +1,15 @@
import { Box, Flex, Image, Text } from '@chakra-ui/react'; import { Box, Flex, Image, Text } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { RemoveFromBoardDropData } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import InvokeAILogoImage from 'assets/images/logo.png'; import InvokeAILogoImage from 'assets/images/logo.png';
import IAIDroppable from 'common/components/IAIDroppable'; import IAIDroppable from 'common/components/IAIDroppable';
import SelectionOverlay from 'common/components/SelectionOverlay'; import SelectionOverlay from 'common/components/SelectionOverlay';
import { RemoveFromBoardDropData } from 'features/dnd/types';
import { import {
boardIdSelected,
autoAddBoardIdChanged, autoAddBoardIdChanged,
boardIdSelected,
} from 'features/gallery/store/gallerySlice'; } from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useMemo, useState } from 'react'; import { memo, useCallback, useMemo, useState } from 'react';
import { useBoardName } from 'services/api/hooks/useBoardName'; import { useBoardName } from 'services/api/hooks/useBoardName';

View File

@ -1,14 +1,14 @@
import { Box, Flex, Image } from '@chakra-ui/react'; import { Box, Flex, Image } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/dist/query'; import { skipToken } from '@reduxjs/toolkit/dist/query';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage'; import IAIDndImage from 'common/components/IAIDndImage';
import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'features/dnd/types';
import { useNextPrevImage } from 'features/gallery/hooks/useNextPrevImage'; import { useNextPrevImage } from 'features/gallery/hooks/useNextPrevImage';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors'; import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { AnimatePresence, motion } from 'framer-motion'; import { AnimatePresence, motion } from 'framer-motion';

View File

@ -11,7 +11,6 @@ import {
autoAssignBoardOnClickChanged, autoAssignBoardOnClickChanged,
setGalleryImageMinimumWidth, setGalleryImageMinimumWidth,
shouldAutoSwitchChanged, shouldAutoSwitchChanged,
shouldShowDeleteButtonChanged,
} from 'features/gallery/store/gallerySlice'; } from 'features/gallery/store/gallerySlice';
import { ChangeEvent, useCallback } from 'react'; import { ChangeEvent, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -26,14 +25,12 @@ const selector = createSelector(
galleryImageMinimumWidth, galleryImageMinimumWidth,
shouldAutoSwitch, shouldAutoSwitch,
autoAssignBoardOnClick, autoAssignBoardOnClick,
shouldShowDeleteButton,
} = state.gallery; } = state.gallery;
return { return {
galleryImageMinimumWidth, galleryImageMinimumWidth,
shouldAutoSwitch, shouldAutoSwitch,
autoAssignBoardOnClick, autoAssignBoardOnClick,
shouldShowDeleteButton,
}; };
}, },
defaultSelectorOptions defaultSelectorOptions
@ -43,12 +40,8 @@ const GallerySettingsPopover = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const { const { galleryImageMinimumWidth, shouldAutoSwitch, autoAssignBoardOnClick } =
galleryImageMinimumWidth, useAppSelector(selector);
shouldAutoSwitch,
autoAssignBoardOnClick,
shouldShowDeleteButton,
} = useAppSelector(selector);
const handleChangeGalleryImageMinimumWidth = useCallback( const handleChangeGalleryImageMinimumWidth = useCallback(
(v: number) => { (v: number) => {
@ -68,13 +61,6 @@ const GallerySettingsPopover = () => {
[dispatch] [dispatch]
); );
const handleChangeShowDeleteButton = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(shouldShowDeleteButtonChanged(e.target.checked));
},
[dispatch]
);
return ( return (
<IAIPopover <IAIPopover
triggerComponent={ triggerComponent={
@ -90,7 +76,7 @@ const GallerySettingsPopover = () => {
<IAISlider <IAISlider
value={galleryImageMinimumWidth} value={galleryImageMinimumWidth}
onChange={handleChangeGalleryImageMinimumWidth} onChange={handleChangeGalleryImageMinimumWidth}
min={32} min={45}
max={256} max={256}
hideTooltip={true} hideTooltip={true}
label={t('gallery.galleryImageSize')} label={t('gallery.galleryImageSize')}
@ -102,11 +88,6 @@ const GallerySettingsPopover = () => {
isChecked={shouldAutoSwitch} isChecked={shouldAutoSwitch}
onChange={handleChangeAutoSwitch} onChange={handleChangeAutoSwitch}
/> />
<IAISwitch
label="Show Delete Button"
isChecked={shouldShowDeleteButton}
onChange={handleChangeShowDeleteButton}
/>
<IAISimpleCheckbox <IAISimpleCheckbox
label={t('gallery.autoAssignBoardOnClick')} label={t('gallery.autoAssignBoardOnClick')}
isChecked={autoAssignBoardOnClick} isChecked={autoAssignBoardOnClick}

View File

@ -1,5 +1,8 @@
import { MenuList } from '@chakra-ui/react'; import { MenuList } from '@chakra-ui/react';
import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu'; import {
IAIContextMenu,
IAIContextMenuProps,
} from 'common/components/IAIContextMenu';
import { MouseEvent, memo, useCallback } from 'react'; import { MouseEvent, memo, useCallback } from 'react';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
import { menuListMotionProps } from 'theme/components/menu'; import { menuListMotionProps } from 'theme/components/menu';
@ -12,7 +15,7 @@ import MultipleSelectionMenuItems from './MultipleSelectionMenuItems';
type Props = { type Props = {
imageDTO: ImageDTO | undefined; imageDTO: ImageDTO | undefined;
children: ContextMenuProps<HTMLDivElement>['children']; children: IAIContextMenuProps<HTMLDivElement>['children'];
}; };
const selector = createSelector( const selector = createSelector(
@ -33,7 +36,7 @@ const ImageContextMenu = ({ imageDTO, children }: Props) => {
}, []); }, []);
return ( return (
<ContextMenu<HTMLDivElement> <IAIContextMenu<HTMLDivElement>
menuProps={{ size: 'sm', isLazy: true }} menuProps={{ size: 'sm', isLazy: true }}
menuButtonProps={{ menuButtonProps={{
bg: 'transparent', bg: 'transparent',
@ -68,7 +71,7 @@ const ImageContextMenu = ({ imageDTO, children }: Props) => {
}} }}
> >
{children} {children}
</ContextMenu> </IAIContextMenu>
); );
}; };

View File

@ -5,13 +5,21 @@ import {
isModalOpenChanged, isModalOpenChanged,
} from 'features/changeBoardModal/store/slice'; } from 'features/changeBoardModal/store/slice';
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice'; import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
import { useCallback } from 'react'; import { useCallback, useMemo } from 'react';
import { FaFolder, FaTrash } from 'react-icons/fa'; import { FaFolder, FaTrash } from 'react-icons/fa';
import { MdStar, MdStarBorder } from 'react-icons/md';
import {
useStarImagesMutation,
useUnstarImagesMutation,
} from '../../../../services/api/endpoints/images';
const MultipleSelectionMenuItems = () => { const MultipleSelectionMenuItems = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const selection = useAppSelector((state) => state.gallery.selection); const selection = useAppSelector((state) => state.gallery.selection);
const [starImages] = useStarImagesMutation();
const [unstarImages] = useUnstarImagesMutation();
const handleChangeBoard = useCallback(() => { const handleChangeBoard = useCallback(() => {
dispatch(imagesToChangeSelected(selection)); dispatch(imagesToChangeSelected(selection));
dispatch(isModalOpenChanged(true)); dispatch(isModalOpenChanged(true));
@ -21,8 +29,37 @@ const MultipleSelectionMenuItems = () => {
dispatch(imagesToDeleteSelected(selection)); dispatch(imagesToDeleteSelected(selection));
}, [dispatch, selection]); }, [dispatch, selection]);
const handleStarSelection = useCallback(() => {
starImages({ imageDTOs: selection });
}, [starImages, selection]);
const handleUnstarSelection = useCallback(() => {
unstarImages({ imageDTOs: selection });
}, [unstarImages, selection]);
const areAllStarred = useMemo(() => {
return selection.every((img) => img.starred);
}, [selection]);
const areAllUnstarred = useMemo(() => {
return selection.every((img) => !img.starred);
}, [selection]);
return ( return (
<> <>
{areAllStarred && (
<MenuItem
icon={<MdStarBorder />}
onClickCapture={handleUnstarSelection}
>
Unstar All
</MenuItem>
)}
{(areAllUnstarred || (!areAllStarred && !areAllUnstarred)) && (
<MenuItem icon={<MdStar />} onClickCapture={handleStarSelection}>
Star All
</MenuItem>
)}
<MenuItem icon={<FaFolder />} onClickCapture={handleChangeBoard}> <MenuItem icon={<FaFolder />} onClickCapture={handleChangeBoard}>
Change Board Change Board
</MenuItem> </MenuItem>

View File

@ -29,10 +29,15 @@ import {
FaShare, FaShare,
FaTrash, FaTrash,
} from 'react-icons/fa'; } from 'react-icons/fa';
import { useGetImageMetadataQuery } from 'services/api/endpoints/images'; import {
useGetImageMetadataQuery,
useStarImagesMutation,
useUnstarImagesMutation,
} from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
import { useDebounce } from 'use-debounce'; import { useDebounce } from 'use-debounce';
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions'; import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
import { MdStar, MdStarBorder } from 'react-icons/md';
type SingleSelectionMenuItemsProps = { type SingleSelectionMenuItemsProps = {
imageDTO: ImageDTO; imageDTO: ImageDTO;
@ -59,6 +64,9 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
: debouncedMetadataQueryArg ?? skipToken : debouncedMetadataQueryArg ?? skipToken
); );
const [starImages] = useStarImagesMutation();
const [unstarImages] = useUnstarImagesMutation();
const { isClipboardAPIAvailable, copyImageToClipboard } = const { isClipboardAPIAvailable, copyImageToClipboard } =
useCopyImageToClipboard(); useCopyImageToClipboard();
@ -127,6 +135,14 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
copyImageToClipboard(imageDTO.image_url); copyImageToClipboard(imageDTO.image_url);
}, [copyImageToClipboard, imageDTO.image_url]); }, [copyImageToClipboard, imageDTO.image_url]);
const handleStarImage = useCallback(() => {
if (imageDTO) starImages({ imageDTOs: [imageDTO] });
}, [starImages, imageDTO]);
const handleUnstarImage = useCallback(() => {
if (imageDTO) unstarImages({ imageDTOs: [imageDTO] });
}, [unstarImages, imageDTO]);
return ( return (
<> <>
<MenuItem <MenuItem
@ -196,6 +212,15 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
<MenuItem icon={<FaFolder />} onClickCapture={handleChangeBoard}> <MenuItem icon={<FaFolder />} onClickCapture={handleChangeBoard}>
Change Board Change Board
</MenuItem> </MenuItem>
{imageDTO.starred ? (
<MenuItem icon={<MdStar />} onClickCapture={handleUnstarImage}>
Unstar Image
</MenuItem>
) : (
<MenuItem icon={<MdStarBorder />} onClickCapture={handleStarImage}>
Star Image
</MenuItem>
)}
<MenuItem <MenuItem
sx={{ color: 'error.600', _dark: { color: 'error.300' } }} sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
icon={<FaTrash />} icon={<FaTrash />}

View File

@ -52,11 +52,13 @@ const ImageGalleryContent = () => {
return ( return (
<VStack <VStack
layerStyle="first"
sx={{ sx={{
flexDirection: 'column', flexDirection: 'column',
h: 'full', h: 'full',
w: 'full', w: 'full',
borderRadius: 'base', borderRadius: 'base',
p: 2,
}} }}
> >
<Box sx={{ w: 'full' }}> <Box sx={{ w: 'full' }}>

View File

@ -1,17 +1,23 @@
import { Box, Flex } from '@chakra-ui/react'; import { Box, Flex } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import IAIFillSkeleton from 'common/components/IAIFillSkeleton';
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
import { import {
ImageDTOsDraggableData, ImageDTOsDraggableData,
ImageDraggableData, ImageDraggableData,
TypesafeDraggableData, TypesafeDraggableData,
} from 'app/components/ImageDnd/typesafeDnd'; } from 'features/dnd/types';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import IAIFillSkeleton from 'common/components/IAIFillSkeleton';
import { useMultiselect } from 'features/gallery/hooks/useMultiselect.ts'; import { useMultiselect } from 'features/gallery/hooks/useMultiselect.ts';
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice'; import { MouseEvent, memo, useCallback, useMemo, useState } from 'react';
import { MouseEvent, memo, useCallback, useMemo } from 'react';
import { FaTrash } from 'react-icons/fa'; import { FaTrash } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { MdStar, MdStarBorder } from 'react-icons/md';
import {
useGetImageDTOQuery,
useStarImagesMutation,
useUnstarImagesMutation,
} from 'services/api/endpoints/images';
import IAIDndImageIcon from '../../../../common/components/IAIDndImageIcon';
interface HoverableImageProps { interface HoverableImageProps {
imageName: string; imageName: string;
@ -21,9 +27,7 @@ const GalleryImage = (props: HoverableImageProps) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { imageName } = props; const { imageName } = props;
const { currentData: imageDTO } = useGetImageDTOQuery(imageName); const { currentData: imageDTO } = useGetImageDTOQuery(imageName);
const shouldShowDeleteButton = useAppSelector( const shift = useAppSelector((state) => state.hotkeys.shift);
(state) => state.gallery.shouldShowDeleteButton
);
const { handleClick, isSelected, selection, selectionCount } = const { handleClick, isSelected, selection, selectionCount } =
useMultiselect(imageDTO); useMultiselect(imageDTO);
@ -59,6 +63,35 @@ const GalleryImage = (props: HoverableImageProps) => {
} }
}, [imageDTO, selection, selectionCount]); }, [imageDTO, selection, selectionCount]);
const [starImages] = useStarImagesMutation();
const [unstarImages] = useUnstarImagesMutation();
const toggleStarredState = useCallback(() => {
if (imageDTO) {
if (imageDTO.starred) {
unstarImages({ imageDTOs: [imageDTO] });
}
if (!imageDTO.starred) {
starImages({ imageDTOs: [imageDTO] });
}
}
}, [starImages, unstarImages, imageDTO]);
const [isHovered, setIsHovered] = useState(false);
const handleMouseOver = useCallback(() => {
setIsHovered(true);
}, []);
const handleMouseOut = useCallback(() => {
setIsHovered(false);
}, []);
const starIcon = useMemo(() => {
if (imageDTO?.starred) return <MdStar size="20" />;
if (!imageDTO?.starred && isHovered) return <MdStarBorder size="20" />;
}, [imageDTO?.starred, isHovered]);
if (!imageDTO) { if (!imageDTO) {
return <IAIFillSkeleton />; return <IAIFillSkeleton />;
} }
@ -80,16 +113,34 @@ const GalleryImage = (props: HoverableImageProps) => {
draggableData={draggableData} draggableData={draggableData}
isSelected={isSelected} isSelected={isSelected}
minSize={0} minSize={0}
onClickReset={handleDelete}
imageSx={{ w: 'full', h: 'full' }} imageSx={{ w: 'full', h: 'full' }}
isDropDisabled={true} isDropDisabled={true}
isUploadDisabled={true} isUploadDisabled={true}
thumbnail={true} thumbnail={true}
withHoverOverlay withHoverOverlay
resetIcon={<FaTrash />} onMouseOver={handleMouseOver}
resetTooltip="Delete image" onMouseOut={handleMouseOut}
withResetIcon={shouldShowDeleteButton} // removed bc it's too easy to accidentally delete images >
<>
<IAIDndImageIcon
onClick={toggleStarredState}
icon={starIcon}
tooltip={imageDTO.starred ? 'Unstar' : 'Star'}
/> />
{isHovered && shift && (
<IAIDndImageIcon
onClick={handleDelete}
icon={<FaTrash />}
tooltip="Delete"
styleOverrides={{
bottom: 2,
top: 'auto',
}}
/>
)}
</>
</IAIDndImage>
</Flex> </Flex>
</Box> </Box>
); );

View File

@ -26,7 +26,7 @@ const overlayScrollbarsConfig: UseOverlayScrollbarsParams = {
options: { options: {
scrollbars: { scrollbars: {
visibility: 'auto', visibility: 'auto',
autoHide: 'leave', autoHide: 'scroll',
autoHideDelay: 1300, autoHideDelay: 1300,
theme: 'os-theme-dark', theme: 'os-theme-dark',
}, },

Some files were not shown because too many files have changed in this diff Show More