mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into refactor/rename-get-logger
This commit is contained in:
commit
fc9b4539a3
@ -5,7 +5,7 @@ from PIL import Image
|
|||||||
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
|
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.app.invocations.metadata import ImageMetadata
|
from invokeai.app.invocations.metadata import ImageMetadata
|
||||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||||
@ -19,6 +19,7 @@ from ..dependencies import ApiDependencies
|
|||||||
|
|
||||||
images_router = APIRouter(prefix="/v1/images", tags=["images"])
|
images_router = APIRouter(prefix="/v1/images", tags=["images"])
|
||||||
|
|
||||||
|
|
||||||
# images are immutable; set a high max-age
|
# images are immutable; set a high max-age
|
||||||
IMAGE_MAX_AGE = 31536000
|
IMAGE_MAX_AGE = 31536000
|
||||||
|
|
||||||
@ -286,3 +287,41 @@ async def delete_images_from_list(
|
|||||||
return DeleteImagesFromListResult(deleted_images=deleted_images)
|
return DeleteImagesFromListResult(deleted_images=deleted_images)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail="Failed to delete images")
|
raise HTTPException(status_code=500, detail="Failed to delete images")
|
||||||
|
|
||||||
|
|
||||||
|
class ImagesUpdatedFromListResult(BaseModel):
|
||||||
|
updated_image_names: list[str] = Field(description="The image names that were updated")
|
||||||
|
|
||||||
|
|
||||||
|
@images_router.post("/star", operation_id="star_images_in_list", response_model=ImagesUpdatedFromListResult)
|
||||||
|
async def star_images_in_list(
|
||||||
|
image_names: list[str] = Body(description="The list of names of images to star", embed=True),
|
||||||
|
) -> ImagesUpdatedFromListResult:
|
||||||
|
try:
|
||||||
|
updated_image_names: list[str] = []
|
||||||
|
for image_name in image_names:
|
||||||
|
try:
|
||||||
|
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=True))
|
||||||
|
updated_image_names.append(image_name)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to star images")
|
||||||
|
|
||||||
|
|
||||||
|
@images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=ImagesUpdatedFromListResult)
|
||||||
|
async def unstar_images_in_list(
|
||||||
|
image_names: list[str] = Body(description="The list of names of images to unstar", embed=True),
|
||||||
|
) -> ImagesUpdatedFromListResult:
|
||||||
|
try:
|
||||||
|
updated_image_names: list[str] = []
|
||||||
|
for image_name in image_names:
|
||||||
|
try:
|
||||||
|
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=False))
|
||||||
|
updated_image_names.append(image_name)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to unstar images")
|
||||||
|
@ -38,7 +38,7 @@ import mimetypes
|
|||||||
from .api.dependencies import ApiDependencies
|
from .api.dependencies import ApiDependencies
|
||||||
from .api.routers import sessions, models, images, boards, board_images, app_info
|
from .api.routers import sessions, models, images, boards, board_images, app_info
|
||||||
from .api.sockets import SocketIO
|
from .api.sockets import SocketIO
|
||||||
from .invocations.baseinvocation import BaseInvocation
|
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -134,6 +134,11 @@ def custom_openapi():
|
|||||||
# This could break in some cases, figure out a better way to do it
|
# This could break in some cases, figure out a better way to do it
|
||||||
output_type_titles[schema_key] = output_schema["title"]
|
output_type_titles[schema_key] = output_schema["title"]
|
||||||
|
|
||||||
|
# Add Node Editor UI helper schemas
|
||||||
|
ui_config_schemas = schema([UIConfigBase, _InputField, _OutputField], ref_prefix="#/components/schemas/")
|
||||||
|
for schema_key, output_schema in ui_config_schemas["definitions"].items():
|
||||||
|
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
||||||
|
|
||||||
# Add a reference to the output type to additionalProperties of the invoker schema
|
# Add a reference to the output type to additionalProperties of the invoker schema
|
||||||
for invoker in all_invocations:
|
for invoker in all_invocations:
|
||||||
invoker_name = invoker.__name__
|
invoker_name = invoker.__name__
|
||||||
|
@ -3,15 +3,366 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from enum import Enum
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args, get_type_hints
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
AbstractSet,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
ClassVar,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
get_args,
|
||||||
|
get_type_hints,
|
||||||
|
)
|
||||||
|
|
||||||
from pydantic import BaseConfig, BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from pydantic.fields import Undefined
|
||||||
|
from pydantic.typing import NoArgAnyCallable
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..services.invocation_services import InvocationServices
|
from ..services.invocation_services import InvocationServices
|
||||||
|
|
||||||
|
|
||||||
|
class FieldDescriptions:
|
||||||
|
denoising_start = "When to start denoising, expressed a percentage of total steps"
|
||||||
|
denoising_end = "When to stop denoising, expressed a percentage of total steps"
|
||||||
|
cfg_scale = "Classifier-Free Guidance scale"
|
||||||
|
scheduler = "Scheduler to use during inference"
|
||||||
|
positive_cond = "Positive conditioning tensor"
|
||||||
|
negative_cond = "Negative conditioning tensor"
|
||||||
|
noise = "Noise tensor"
|
||||||
|
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
||||||
|
unet = "UNet (scheduler, LoRAs)"
|
||||||
|
vae = "VAE"
|
||||||
|
cond = "Conditioning tensor"
|
||||||
|
controlnet_model = "ControlNet model to load"
|
||||||
|
vae_model = "VAE model to load"
|
||||||
|
lora_model = "LoRA model to load"
|
||||||
|
main_model = "Main model (UNet, VAE, CLIP) to load"
|
||||||
|
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
||||||
|
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
||||||
|
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
||||||
|
lora_weight = "The weight at which the LoRA is applied to each model"
|
||||||
|
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
|
||||||
|
raw_prompt = "Raw prompt text (no parsing)"
|
||||||
|
sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor"
|
||||||
|
skipped_layers = "Number of layers to skip in text encoder"
|
||||||
|
seed = "Seed for random number generation"
|
||||||
|
steps = "Number of steps to run"
|
||||||
|
width = "Width of output (px)"
|
||||||
|
height = "Height of output (px)"
|
||||||
|
control = "ControlNet(s) to apply"
|
||||||
|
denoised_latents = "Denoised latents tensor"
|
||||||
|
latents = "Latents tensor"
|
||||||
|
strength = "Strength of denoising (proportional to steps)"
|
||||||
|
core_metadata = "Optional core metadata to be written to image"
|
||||||
|
interp_mode = "Interpolation mode"
|
||||||
|
torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
|
||||||
|
fp32 = "Whether or not to use full float32 precision"
|
||||||
|
precision = "Precision to use"
|
||||||
|
tiled = "Processing using overlapping tiles (reduce memory consumption)"
|
||||||
|
detect_res = "Pixel resolution for detection"
|
||||||
|
image_res = "Pixel resolution for output image"
|
||||||
|
safe_mode = "Whether or not to use safe mode"
|
||||||
|
scribble_mode = "Whether or not to use scribble mode"
|
||||||
|
scale_factor = "The factor by which to scale"
|
||||||
|
num_1 = "The first number"
|
||||||
|
num_2 = "The second number"
|
||||||
|
mask = "The mask to use for the operation"
|
||||||
|
|
||||||
|
|
||||||
|
class Input(str, Enum):
|
||||||
|
"""
|
||||||
|
The type of input a field accepts.
|
||||||
|
- `Input.Direct`: The field must have its value provided directly, when the invocation and field \
|
||||||
|
are instantiated.
|
||||||
|
- `Input.Connection`: The field must have its value provided by a connection.
|
||||||
|
- `Input.Any`: The field may have its value provided either directly or by a connection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
Connection = "connection"
|
||||||
|
Direct = "direct"
|
||||||
|
Any = "any"
|
||||||
|
|
||||||
|
|
||||||
|
class UIType(str, Enum):
|
||||||
|
"""
|
||||||
|
Type hints for the UI.
|
||||||
|
If a field should be provided a data type that does not exactly match the python type of the field, \
|
||||||
|
use this to provide the type that should be used instead. See the node development docs for detail \
|
||||||
|
on adding a new field type, which involves client-side changes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# region Primitives
|
||||||
|
Integer = "integer"
|
||||||
|
Float = "float"
|
||||||
|
Boolean = "boolean"
|
||||||
|
String = "string"
|
||||||
|
Array = "array"
|
||||||
|
Image = "ImageField"
|
||||||
|
Latents = "LatentsField"
|
||||||
|
Conditioning = "ConditioningField"
|
||||||
|
Control = "ControlField"
|
||||||
|
Color = "ColorField"
|
||||||
|
ImageCollection = "ImageCollection"
|
||||||
|
ConditioningCollection = "ConditioningCollection"
|
||||||
|
ColorCollection = "ColorCollection"
|
||||||
|
LatentsCollection = "LatentsCollection"
|
||||||
|
IntegerCollection = "IntegerCollection"
|
||||||
|
FloatCollection = "FloatCollection"
|
||||||
|
StringCollection = "StringCollection"
|
||||||
|
BooleanCollection = "BooleanCollection"
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
# region Models
|
||||||
|
MainModel = "MainModelField"
|
||||||
|
SDXLMainModel = "SDXLMainModelField"
|
||||||
|
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||||
|
ONNXModel = "ONNXModelField"
|
||||||
|
VaeModel = "VaeModelField"
|
||||||
|
LoRAModel = "LoRAModelField"
|
||||||
|
ControlNetModel = "ControlNetModelField"
|
||||||
|
UNet = "UNetField"
|
||||||
|
Vae = "VaeField"
|
||||||
|
CLIP = "ClipField"
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
# region Iterate/Collect
|
||||||
|
Collection = "Collection"
|
||||||
|
CollectionItem = "CollectionItem"
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
# region Misc
|
||||||
|
FilePath = "FilePath"
|
||||||
|
Enum = "enum"
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
|
class UIComponent(str, Enum):
|
||||||
|
"""
|
||||||
|
The type of UI component to use for a field, used to override the default components, which are \
|
||||||
|
inferred from the field type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
None_ = "none"
|
||||||
|
Textarea = "textarea"
|
||||||
|
Slider = "slider"
|
||||||
|
|
||||||
|
|
||||||
|
class _InputField(BaseModel):
|
||||||
|
"""
|
||||||
|
*DO NOT USE*
|
||||||
|
This helper class is used to tell the client about our custom field attributes via OpenAPI
|
||||||
|
schema generation, and Typescript type generation from that schema. It serves no functional
|
||||||
|
purpose in the backend.
|
||||||
|
"""
|
||||||
|
|
||||||
|
input: Input
|
||||||
|
ui_hidden: bool
|
||||||
|
ui_type: Optional[UIType]
|
||||||
|
ui_component: Optional[UIComponent]
|
||||||
|
|
||||||
|
|
||||||
|
class _OutputField(BaseModel):
|
||||||
|
"""
|
||||||
|
*DO NOT USE*
|
||||||
|
This helper class is used to tell the client about our custom field attributes via OpenAPI
|
||||||
|
schema generation, and Typescript type generation from that schema. It serves no functional
|
||||||
|
purpose in the backend.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ui_hidden: bool
|
||||||
|
ui_type: Optional[UIType]
|
||||||
|
|
||||||
|
|
||||||
|
def InputField(
|
||||||
|
*args: Any,
|
||||||
|
default: Any = Undefined,
|
||||||
|
default_factory: Optional[NoArgAnyCallable] = None,
|
||||||
|
alias: Optional[str] = None,
|
||||||
|
title: Optional[str] = None,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
exclude: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
||||||
|
include: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
||||||
|
const: Optional[bool] = None,
|
||||||
|
gt: Optional[float] = None,
|
||||||
|
ge: Optional[float] = None,
|
||||||
|
lt: Optional[float] = None,
|
||||||
|
le: Optional[float] = None,
|
||||||
|
multiple_of: Optional[float] = None,
|
||||||
|
allow_inf_nan: Optional[bool] = None,
|
||||||
|
max_digits: Optional[int] = None,
|
||||||
|
decimal_places: Optional[int] = None,
|
||||||
|
min_items: Optional[int] = None,
|
||||||
|
max_items: Optional[int] = None,
|
||||||
|
unique_items: Optional[bool] = None,
|
||||||
|
min_length: Optional[int] = None,
|
||||||
|
max_length: Optional[int] = None,
|
||||||
|
allow_mutation: bool = True,
|
||||||
|
regex: Optional[str] = None,
|
||||||
|
discriminator: Optional[str] = None,
|
||||||
|
repr: bool = True,
|
||||||
|
input: Input = Input.Any,
|
||||||
|
ui_type: Optional[UIType] = None,
|
||||||
|
ui_component: Optional[UIComponent] = None,
|
||||||
|
ui_hidden: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Creates an input field for an invocation.
|
||||||
|
|
||||||
|
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
|
||||||
|
that adds a few extra parameters to support graph execution and the node editor UI.
|
||||||
|
|
||||||
|
:param Input input: [Input.Any] The kind of input this field requires. \
|
||||||
|
`Input.Direct` means a value must be provided on instantiation. \
|
||||||
|
`Input.Connection` means the value must be provided by a connection. \
|
||||||
|
`Input.Any` means either will do.
|
||||||
|
|
||||||
|
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
|
||||||
|
In some situations, the field's type is not enough to infer the correct UI type. \
|
||||||
|
For example, model selection fields should render a dropdown UI component to select a model. \
|
||||||
|
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
||||||
|
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
||||||
|
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||||
|
|
||||||
|
:param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \
|
||||||
|
The UI will always render a suitable component, but sometimes you want something different than the default. \
|
||||||
|
For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \
|
||||||
|
For this case, you could provide `UIComponent.Textarea`.
|
||||||
|
|
||||||
|
: param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
|
||||||
|
"""
|
||||||
|
return Field(
|
||||||
|
*args,
|
||||||
|
default=default,
|
||||||
|
default_factory=default_factory,
|
||||||
|
alias=alias,
|
||||||
|
title=title,
|
||||||
|
description=description,
|
||||||
|
exclude=exclude,
|
||||||
|
include=include,
|
||||||
|
const=const,
|
||||||
|
gt=gt,
|
||||||
|
ge=ge,
|
||||||
|
lt=lt,
|
||||||
|
le=le,
|
||||||
|
multiple_of=multiple_of,
|
||||||
|
allow_inf_nan=allow_inf_nan,
|
||||||
|
max_digits=max_digits,
|
||||||
|
decimal_places=decimal_places,
|
||||||
|
min_items=min_items,
|
||||||
|
max_items=max_items,
|
||||||
|
unique_items=unique_items,
|
||||||
|
min_length=min_length,
|
||||||
|
max_length=max_length,
|
||||||
|
allow_mutation=allow_mutation,
|
||||||
|
regex=regex,
|
||||||
|
discriminator=discriminator,
|
||||||
|
repr=repr,
|
||||||
|
input=input,
|
||||||
|
ui_type=ui_type,
|
||||||
|
ui_component=ui_component,
|
||||||
|
ui_hidden=ui_hidden,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def OutputField(
|
||||||
|
*args: Any,
|
||||||
|
default: Any = Undefined,
|
||||||
|
default_factory: Optional[NoArgAnyCallable] = None,
|
||||||
|
alias: Optional[str] = None,
|
||||||
|
title: Optional[str] = None,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
exclude: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
||||||
|
include: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
||||||
|
const: Optional[bool] = None,
|
||||||
|
gt: Optional[float] = None,
|
||||||
|
ge: Optional[float] = None,
|
||||||
|
lt: Optional[float] = None,
|
||||||
|
le: Optional[float] = None,
|
||||||
|
multiple_of: Optional[float] = None,
|
||||||
|
allow_inf_nan: Optional[bool] = None,
|
||||||
|
max_digits: Optional[int] = None,
|
||||||
|
decimal_places: Optional[int] = None,
|
||||||
|
min_items: Optional[int] = None,
|
||||||
|
max_items: Optional[int] = None,
|
||||||
|
unique_items: Optional[bool] = None,
|
||||||
|
min_length: Optional[int] = None,
|
||||||
|
max_length: Optional[int] = None,
|
||||||
|
allow_mutation: bool = True,
|
||||||
|
regex: Optional[str] = None,
|
||||||
|
discriminator: Optional[str] = None,
|
||||||
|
repr: bool = True,
|
||||||
|
ui_type: Optional[UIType] = None,
|
||||||
|
ui_hidden: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Creates an output field for an invocation output.
|
||||||
|
|
||||||
|
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
|
||||||
|
that adds a few extra parameters to support graph execution and the node editor UI.
|
||||||
|
|
||||||
|
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
|
||||||
|
In some situations, the field's type is not enough to infer the correct UI type. \
|
||||||
|
For example, model selection fields should render a dropdown UI component to select a model. \
|
||||||
|
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
||||||
|
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
||||||
|
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||||
|
|
||||||
|
: param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
|
||||||
|
"""
|
||||||
|
return Field(
|
||||||
|
*args,
|
||||||
|
default=default,
|
||||||
|
default_factory=default_factory,
|
||||||
|
alias=alias,
|
||||||
|
title=title,
|
||||||
|
description=description,
|
||||||
|
exclude=exclude,
|
||||||
|
include=include,
|
||||||
|
const=const,
|
||||||
|
gt=gt,
|
||||||
|
ge=ge,
|
||||||
|
lt=lt,
|
||||||
|
le=le,
|
||||||
|
multiple_of=multiple_of,
|
||||||
|
allow_inf_nan=allow_inf_nan,
|
||||||
|
max_digits=max_digits,
|
||||||
|
decimal_places=decimal_places,
|
||||||
|
min_items=min_items,
|
||||||
|
max_items=max_items,
|
||||||
|
unique_items=unique_items,
|
||||||
|
min_length=min_length,
|
||||||
|
max_length=max_length,
|
||||||
|
allow_mutation=allow_mutation,
|
||||||
|
regex=regex,
|
||||||
|
discriminator=discriminator,
|
||||||
|
repr=repr,
|
||||||
|
ui_type=ui_type,
|
||||||
|
ui_hidden=ui_hidden,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UIConfigBase(BaseModel):
|
||||||
|
"""
|
||||||
|
Provides additional node configuration to the UI.
|
||||||
|
This is used internally by the @tags and @title decorator logic. You probably want to use those
|
||||||
|
decorators, though you may add this class to a node definition to specify the title and tags.
|
||||||
|
"""
|
||||||
|
|
||||||
|
tags: Optional[list[str]] = Field(default_factory=None, description="The tags to display in the UI")
|
||||||
|
title: Optional[str] = Field(default=None, description="The display name of the node")
|
||||||
|
|
||||||
|
|
||||||
class InvocationContext:
|
class InvocationContext:
|
||||||
services: InvocationServices
|
services: InvocationServices
|
||||||
graph_execution_state_id: str
|
graph_execution_state_id: str
|
||||||
@ -39,6 +390,20 @@ class BaseInvocationOutput(BaseModel):
|
|||||||
return tuple(subclasses)
|
return tuple(subclasses)
|
||||||
|
|
||||||
|
|
||||||
|
class RequiredConnectionException(Exception):
|
||||||
|
"""Raised when an field which requires a connection did not receive a value."""
|
||||||
|
|
||||||
|
def __init__(self, node_id: str, field_name: str):
|
||||||
|
super().__init__(f"Node {node_id} missing connections for field {field_name}")
|
||||||
|
|
||||||
|
|
||||||
|
class MissingInputException(Exception):
|
||||||
|
"""Raised when an field which requires some input, but did not receive a value."""
|
||||||
|
|
||||||
|
def __init__(self, node_id: str, field_name: str):
|
||||||
|
super().__init__(f"Node {node_id} missing value or connection for field {field_name}")
|
||||||
|
|
||||||
|
|
||||||
class BaseInvocation(ABC, BaseModel):
|
class BaseInvocation(ABC, BaseModel):
|
||||||
"""A node to process inputs and produce outputs.
|
"""A node to process inputs and produce outputs.
|
||||||
May use dependency injection in __init__ to receive providers.
|
May use dependency injection in __init__ to receive providers.
|
||||||
@ -76,70 +441,81 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
def get_output_type(cls):
|
def get_output_type(cls):
|
||||||
return signature(cls.invoke).return_annotation
|
return signature(cls.invoke).return_annotation
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
@staticmethod
|
||||||
|
def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||||
|
uiconfig = getattr(model_class, "UIConfig", None)
|
||||||
|
if uiconfig and hasattr(uiconfig, "title"):
|
||||||
|
schema["title"] = uiconfig.title
|
||||||
|
if uiconfig and hasattr(uiconfig, "tags"):
|
||||||
|
schema["tags"] = uiconfig.tags
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
|
def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
|
||||||
"""Invoke with provided context and return outputs."""
|
"""Invoke with provided context and return outputs."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# fmt: off
|
def __init__(self, **data):
|
||||||
id: str = Field(description="The id of this node. Must be unique among all nodes.")
|
# nodes may have required fields, that can accept input from connections
|
||||||
is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
|
# on instantiation of the model, we need to exclude these from validation
|
||||||
# fmt: on
|
restore = dict()
|
||||||
|
try:
|
||||||
|
field_names = list(self.__fields__.keys())
|
||||||
|
for field_name in field_names:
|
||||||
|
# if the field is required and may get its value from a connection, exclude it from validation
|
||||||
|
field = self.__fields__[field_name]
|
||||||
|
_input = field.field_info.extra.get("input", None)
|
||||||
|
if _input in [Input.Connection, Input.Any] and field.required:
|
||||||
|
if field_name not in data:
|
||||||
|
restore[field_name] = self.__fields__.pop(field_name)
|
||||||
|
# instantiate the node, which will validate the data
|
||||||
|
super().__init__(**data)
|
||||||
|
finally:
|
||||||
|
# restore the removed fields
|
||||||
|
for field_name, field in restore.items():
|
||||||
|
self.__fields__[field_name] = field
|
||||||
|
|
||||||
|
def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
|
||||||
|
for field_name, field in self.__fields__.items():
|
||||||
|
_input = field.field_info.extra.get("input", None)
|
||||||
|
if field.required and not hasattr(self, field_name):
|
||||||
|
if _input == Input.Connection:
|
||||||
|
raise RequiredConnectionException(self.__fields__["type"].default, field_name)
|
||||||
|
elif _input == Input.Any:
|
||||||
|
raise MissingInputException(self.__fields__["type"].default, field_name)
|
||||||
|
return self.invoke(context)
|
||||||
|
|
||||||
|
id: str = InputField(description="The id of this node. Must be unique among all nodes.")
|
||||||
|
is_intermediate: bool = InputField(
|
||||||
|
default=False, description="Whether or not this node is an intermediate node.", input=Input.Direct
|
||||||
|
)
|
||||||
|
UIConfig: ClassVar[Type[UIConfigBase]]
|
||||||
|
|
||||||
|
|
||||||
# TODO: figure out a better way to provide these hints
|
T = TypeVar("T", bound=BaseInvocation)
|
||||||
# TODO: when we can upgrade to python 3.11, we can use the`NotRequired` type instead of `total=False`
|
|
||||||
class UIConfig(TypedDict, total=False):
|
|
||||||
type_hints: Dict[
|
|
||||||
str,
|
|
||||||
Literal[
|
|
||||||
"integer",
|
|
||||||
"float",
|
|
||||||
"boolean",
|
|
||||||
"string",
|
|
||||||
"enum",
|
|
||||||
"image",
|
|
||||||
"latents",
|
|
||||||
"model",
|
|
||||||
"control",
|
|
||||||
"image_collection",
|
|
||||||
"vae_model",
|
|
||||||
"lora_model",
|
|
||||||
],
|
|
||||||
]
|
|
||||||
tags: List[str]
|
|
||||||
title: str
|
|
||||||
|
|
||||||
|
|
||||||
class CustomisedSchemaExtra(TypedDict):
|
def title(title: str) -> Callable[[Type[T]], Type[T]]:
|
||||||
ui: UIConfig
|
"""Adds a title to the invocation. Use this to override the default title generation, which is based on the class name."""
|
||||||
|
|
||||||
|
def wrapper(cls: Type[T]) -> Type[T]:
|
||||||
|
uiconf_name = cls.__qualname__ + ".UIConfig"
|
||||||
|
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
||||||
|
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
|
||||||
|
cls.UIConfig.title = title
|
||||||
|
return cls
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
class InvocationConfig(BaseConfig):
|
def tags(*tags: str) -> Callable[[Type[T]], Type[T]]:
|
||||||
"""Customizes pydantic's BaseModel.Config class for use by Invocations.
|
"""Adds tags to the invocation. Use this to improve the streamline finding the invocation in the UI."""
|
||||||
|
|
||||||
Provide `schema_extra` a `ui` dict to add hints for generated UIs.
|
def wrapper(cls: Type[T]) -> Type[T]:
|
||||||
|
uiconf_name = cls.__qualname__ + ".UIConfig"
|
||||||
|
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
||||||
|
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
|
||||||
|
cls.UIConfig.tags = list(tags)
|
||||||
|
return cls
|
||||||
|
|
||||||
`tags`
|
return wrapper
|
||||||
- A list of strings, used to categorise invocations.
|
|
||||||
|
|
||||||
`type_hints`
|
|
||||||
- A dict of field types which override the types in the invocation definition.
|
|
||||||
- Each key should be the name of one of the invocation's fields.
|
|
||||||
- Each value should be one of the valid types:
|
|
||||||
- `integer`, `float`, `boolean`, `string`, `enum`, `image`, `latents`, `model`
|
|
||||||
|
|
||||||
```python
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["stable-diffusion", "image"],
|
|
||||||
"type_hints": {
|
|
||||||
"initial_image": "image",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
schema_extra: CustomisedSchemaExtra
|
|
||||||
|
@ -3,58 +3,25 @@
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import Field, validator
|
from pydantic import validator
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageField
|
from invokeai.app.invocations.primitives import ImageCollectionOutput, ImageField, IntegerCollectionOutput
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext, UIConfig
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIType, tags, title
|
||||||
|
|
||||||
|
|
||||||
class IntCollectionOutput(BaseInvocationOutput):
|
|
||||||
"""A collection of integers"""
|
|
||||||
|
|
||||||
type: Literal["int_collection"] = "int_collection"
|
|
||||||
|
|
||||||
# Outputs
|
|
||||||
collection: list[int] = Field(default=[], description="The int collection")
|
|
||||||
|
|
||||||
|
|
||||||
class FloatCollectionOutput(BaseInvocationOutput):
|
|
||||||
"""A collection of floats"""
|
|
||||||
|
|
||||||
type: Literal["float_collection"] = "float_collection"
|
|
||||||
|
|
||||||
# Outputs
|
|
||||||
collection: list[float] = Field(default=[], description="The float collection")
|
|
||||||
|
|
||||||
|
|
||||||
class ImageCollectionOutput(BaseInvocationOutput):
|
|
||||||
"""A collection of images"""
|
|
||||||
|
|
||||||
type: Literal["image_collection"] = "image_collection"
|
|
||||||
|
|
||||||
# Outputs
|
|
||||||
collection: list[ImageField] = Field(default=[], description="The output images")
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
schema_extra = {"required": ["type", "collection"]}
|
|
||||||
|
|
||||||
|
|
||||||
|
@title("Integer Range")
|
||||||
|
@tags("collection", "integer", "range")
|
||||||
class RangeInvocation(BaseInvocation):
|
class RangeInvocation(BaseInvocation):
|
||||||
"""Creates a range of numbers from start to stop with step"""
|
"""Creates a range of numbers from start to stop with step"""
|
||||||
|
|
||||||
type: Literal["range"] = "range"
|
type: Literal["range"] = "range"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
start: int = Field(default=0, description="The start of the range")
|
start: int = InputField(default=0, description="The start of the range")
|
||||||
stop: int = Field(default=10, description="The stop of the range")
|
stop: int = InputField(default=10, description="The stop of the range")
|
||||||
step: int = Field(default=1, description="The step of the range")
|
step: int = InputField(default=1, description="The step of the range")
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Range", "tags": ["range", "integer", "collection"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
@validator("stop")
|
@validator("stop")
|
||||||
def stop_gt_start(cls, v, values):
|
def stop_gt_start(cls, v, values):
|
||||||
@ -62,76 +29,44 @@ class RangeInvocation(BaseInvocation):
|
|||||||
raise ValueError("stop must be greater than start")
|
raise ValueError("stop must be greater than start")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
|
||||||
return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
|
return IntegerCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
|
||||||
|
|
||||||
|
|
||||||
|
@title("Integer Range of Size")
|
||||||
|
@tags("range", "integer", "size", "collection")
|
||||||
class RangeOfSizeInvocation(BaseInvocation):
|
class RangeOfSizeInvocation(BaseInvocation):
|
||||||
"""Creates a range from start to start + size with step"""
|
"""Creates a range from start to start + size with step"""
|
||||||
|
|
||||||
type: Literal["range_of_size"] = "range_of_size"
|
type: Literal["range_of_size"] = "range_of_size"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
start: int = Field(default=0, description="The start of the range")
|
start: int = InputField(default=0, description="The start of the range")
|
||||||
size: int = Field(default=1, description="The number of values")
|
size: int = InputField(default=1, description="The number of values")
|
||||||
step: int = Field(default=1, description="The step of the range")
|
step: int = InputField(default=1, description="The step of the range")
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
|
||||||
schema_extra = {
|
return IntegerCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))
|
||||||
"ui": {"title": "Sized Range", "tags": ["range", "integer", "size", "collection"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
|
||||||
return IntCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))
|
|
||||||
|
|
||||||
|
|
||||||
|
@title("Random Range")
|
||||||
|
@tags("range", "integer", "random", "collection")
|
||||||
class RandomRangeInvocation(BaseInvocation):
|
class RandomRangeInvocation(BaseInvocation):
|
||||||
"""Creates a collection of random numbers"""
|
"""Creates a collection of random numbers"""
|
||||||
|
|
||||||
type: Literal["random_range"] = "random_range"
|
type: Literal["random_range"] = "random_range"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
low: int = Field(default=0, description="The inclusive low value")
|
low: int = InputField(default=0, description="The inclusive low value")
|
||||||
high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
||||||
size: int = Field(default=1, description="The number of values to generate")
|
size: int = InputField(default=1, description="The number of values to generate")
|
||||||
seed: int = Field(
|
seed: int = InputField(
|
||||||
ge=0,
|
ge=0,
|
||||||
le=SEED_MAX,
|
le=SEED_MAX,
|
||||||
description="The seed for the RNG (omit for random)",
|
description="The seed for the RNG (omit for random)",
|
||||||
default_factory=get_random_seed,
|
default_factory=get_random_seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Random Range", "tags": ["range", "integer", "random", "collection"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
|
||||||
rng = np.random.default_rng(self.seed)
|
rng = np.random.default_rng(self.seed)
|
||||||
return IntCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size)))
|
return IntegerCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size)))
|
||||||
|
|
||||||
|
|
||||||
class ImageCollectionInvocation(BaseInvocation):
|
|
||||||
"""Load a collection of images and provide it as output."""
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["image_collection"] = "image_collection"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
images: list[ImageField] = Field(
|
|
||||||
default=[], description="The image collection to load"
|
|
||||||
)
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
|
|
||||||
return ImageCollectionOutput(collection=self.images)
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"type_hints": {
|
|
||||||
"title": "Image Collection",
|
|
||||||
"images": "image_collection",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
@ -1,32 +1,35 @@
|
|||||||
from typing import Literal, Optional, Union, List, Annotated
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
import re
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
from typing import List, Literal, Union
|
||||||
from .model import ClipField
|
|
||||||
|
|
||||||
from ...backend.util.devices import torch_dtype
|
|
||||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
|
||||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType, ModelPatcher
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from compel import Compel, ReturnedEmbeddingsType
|
from compel import Compel, ReturnedEmbeddingsType
|
||||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||||
from ...backend.util.devices import torch_dtype
|
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
|
||||||
from ...backend.model_management import ModelType
|
|
||||||
from ...backend.model_management.models import ModelNotFoundException
|
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import (
|
||||||
|
BasicConditioningInfo,
|
||||||
|
SDXLConditioningInfo,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ...backend.model_management import ModelPatcher, ModelType
|
||||||
from ...backend.model_management.lora import ModelPatcher
|
from ...backend.model_management.lora import ModelPatcher
|
||||||
from ...backend.stable_diffusion import InvokeAIDiffuserComponent, BasicConditioningInfo, SDXLConditioningInfo
|
from ...backend.model_management.models import ModelNotFoundException
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||||
|
from ...backend.util.devices import torch_dtype
|
||||||
|
from .baseinvocation import (
|
||||||
|
BaseInvocation,
|
||||||
|
BaseInvocationOutput,
|
||||||
|
FieldDescriptions,
|
||||||
|
Input,
|
||||||
|
InputField,
|
||||||
|
InvocationContext,
|
||||||
|
OutputField,
|
||||||
|
UIComponent,
|
||||||
|
tags,
|
||||||
|
title,
|
||||||
|
)
|
||||||
from .model import ClipField
|
from .model import ClipField
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
|
|
||||||
class ConditioningField(BaseModel):
|
|
||||||
conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
schema_extra = {"required": ["conditioning_name"]}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -41,32 +44,26 @@ class ConditioningFieldData:
|
|||||||
# PerpNeg = "perp_neg"
|
# PerpNeg = "perp_neg"
|
||||||
|
|
||||||
|
|
||||||
class CompelOutput(BaseInvocationOutput):
|
@title("Compel Prompt")
|
||||||
"""Compel parser output"""
|
@tags("prompt", "compel")
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["compel_output"] = "compel_output"
|
|
||||||
|
|
||||||
conditioning: ConditioningField = Field(default=None, description="Conditioning")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
class CompelInvocation(BaseInvocation):
|
class CompelInvocation(BaseInvocation):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
|
|
||||||
type: Literal["compel"] = "compel"
|
type: Literal["compel"] = "compel"
|
||||||
|
|
||||||
prompt: str = Field(default="", description="Prompt")
|
prompt: str = InputField(
|
||||||
clip: ClipField = Field(None, description="Clip to use")
|
default="",
|
||||||
|
description=FieldDescriptions.compel_prompt,
|
||||||
# Schema customisation
|
ui_component=UIComponent.Textarea,
|
||||||
class Config(InvocationConfig):
|
)
|
||||||
schema_extra = {
|
clip: ClipField = InputField(
|
||||||
"ui": {"title": "Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
|
title="CLIP",
|
||||||
}
|
description=FieldDescriptions.clip,
|
||||||
|
input=Input.Connection,
|
||||||
|
)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
**self.clip.tokenizer.dict(),
|
**self.clip.tokenizer.dict(),
|
||||||
context=context,
|
context=context,
|
||||||
@ -149,7 +146,7 @@ class CompelInvocation(BaseInvocation):
|
|||||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||||
context.services.latents.save(conditioning_name, conditioning_data)
|
context.services.latents.save(conditioning_name, conditioning_data)
|
||||||
|
|
||||||
return CompelOutput(
|
return ConditioningOutput(
|
||||||
conditioning=ConditioningField(
|
conditioning=ConditioningField(
|
||||||
conditioning_name=conditioning_name,
|
conditioning_name=conditioning_name,
|
||||||
),
|
),
|
||||||
@ -270,30 +267,26 @@ class SDXLPromptInvocationBase:
|
|||||||
return c, c_pooled, ec
|
return c, c_pooled, ec
|
||||||
|
|
||||||
|
|
||||||
|
@title("SDXL Compel Prompt")
|
||||||
|
@tags("sdxl", "compel", "prompt")
|
||||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
|
|
||||||
type: Literal["sdxl_compel_prompt"] = "sdxl_compel_prompt"
|
type: Literal["sdxl_compel_prompt"] = "sdxl_compel_prompt"
|
||||||
|
|
||||||
prompt: str = Field(default="", description="Prompt")
|
prompt: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
|
||||||
style: str = Field(default="", description="Style prompt")
|
style: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
|
||||||
original_width: int = Field(1024, description="")
|
original_width: int = InputField(default=1024, description="")
|
||||||
original_height: int = Field(1024, description="")
|
original_height: int = InputField(default=1024, description="")
|
||||||
crop_top: int = Field(0, description="")
|
crop_top: int = InputField(default=0, description="")
|
||||||
crop_left: int = Field(0, description="")
|
crop_left: int = InputField(default=0, description="")
|
||||||
target_width: int = Field(1024, description="")
|
target_width: int = InputField(default=1024, description="")
|
||||||
target_height: int = Field(1024, description="")
|
target_height: int = InputField(default=1024, description="")
|
||||||
clip: ClipField = Field(None, description="Clip to use")
|
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||||
clip2: ClipField = Field(None, description="Clip2 to use")
|
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "SDXL Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
|
|
||||||
}
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
c1, c1_pooled, ec1 = self.run_clip_compel(
|
c1, c1_pooled, ec1 = self.run_clip_compel(
|
||||||
context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
|
context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
|
||||||
)
|
)
|
||||||
@ -326,38 +319,32 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||||
context.services.latents.save(conditioning_name, conditioning_data)
|
context.services.latents.save(conditioning_name, conditioning_data)
|
||||||
|
|
||||||
return CompelOutput(
|
return ConditioningOutput(
|
||||||
conditioning=ConditioningField(
|
conditioning=ConditioningField(
|
||||||
conditioning_name=conditioning_name,
|
conditioning_name=conditioning_name,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("SDXL Refiner Compel Prompt")
|
||||||
|
@tags("sdxl", "compel", "prompt")
|
||||||
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
|
|
||||||
type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt"
|
type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt"
|
||||||
|
|
||||||
style: str = Field(default="", description="Style prompt") # TODO: ?
|
style: str = InputField(
|
||||||
original_width: int = Field(1024, description="")
|
default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea
|
||||||
original_height: int = Field(1024, description="")
|
) # TODO: ?
|
||||||
crop_top: int = Field(0, description="")
|
original_width: int = InputField(default=1024, description="")
|
||||||
crop_left: int = Field(0, description="")
|
original_height: int = InputField(default=1024, description="")
|
||||||
aesthetic_score: float = Field(6.0, description="")
|
crop_top: int = InputField(default=0, description="")
|
||||||
clip2: ClipField = Field(None, description="Clip to use")
|
crop_left: int = InputField(default=0, description="")
|
||||||
|
aesthetic_score: float = InputField(default=6.0, description=FieldDescriptions.sdxl_aesthetic)
|
||||||
# Schema customisation
|
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "SDXL Refiner Prompt (Compel)",
|
|
||||||
"tags": ["prompt", "compel"],
|
|
||||||
"type_hints": {"model": "model"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
# TODO: if there will appear lora for refiner - write proper prefix
|
# TODO: if there will appear lora for refiner - write proper prefix
|
||||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>", zero_on_empty=False)
|
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>", zero_on_empty=False)
|
||||||
|
|
||||||
@ -380,7 +367,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
|||||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||||
context.services.latents.save(conditioning_name, conditioning_data)
|
context.services.latents.save(conditioning_name, conditioning_data)
|
||||||
|
|
||||||
return CompelOutput(
|
return ConditioningOutput(
|
||||||
conditioning=ConditioningField(
|
conditioning=ConditioningField(
|
||||||
conditioning_name=conditioning_name,
|
conditioning_name=conditioning_name,
|
||||||
),
|
),
|
||||||
@ -391,21 +378,18 @@ class ClipSkipInvocationOutput(BaseInvocationOutput):
|
|||||||
"""Clip skip node output"""
|
"""Clip skip node output"""
|
||||||
|
|
||||||
type: Literal["clip_skip_output"] = "clip_skip_output"
|
type: Literal["clip_skip_output"] = "clip_skip_output"
|
||||||
clip: ClipField = Field(None, description="Clip with skipped layers")
|
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||||
|
|
||||||
|
|
||||||
|
@title("CLIP Skip")
|
||||||
|
@tags("clipskip", "clip", "skip")
|
||||||
class ClipSkipInvocation(BaseInvocation):
|
class ClipSkipInvocation(BaseInvocation):
|
||||||
"""Skip layers in clip text_encoder model."""
|
"""Skip layers in clip text_encoder model."""
|
||||||
|
|
||||||
type: Literal["clip_skip"] = "clip_skip"
|
type: Literal["clip_skip"] = "clip_skip"
|
||||||
|
|
||||||
clip: ClipField = Field(None, description="Clip to use")
|
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
|
||||||
skipped_layers: int = Field(0, description="Number of layers to skip in text_encoder")
|
skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "CLIP Skip", "tags": ["clip", "skip"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
|
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
|
||||||
self.clip.skipped_layers += self.skipped_layers
|
self.clip.skipped_layers += self.skipped_layers
|
||||||
|
@ -26,79 +26,31 @@ from controlnet_aux.util import HWC3, ade_palette
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pydantic import BaseModel, Field, validator
|
from pydantic import BaseModel, Field, validator
|
||||||
|
|
||||||
|
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||||
|
|
||||||
|
|
||||||
from ...backend.model_management import BaseModelType, ModelType
|
from ...backend.model_management import BaseModelType, ModelType
|
||||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
from ..models.image import ImageCategory, ResourceOrigin
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
from .baseinvocation import (
|
||||||
from ..models.image import ImageOutput, PILInvocationConfig
|
BaseInvocation,
|
||||||
|
BaseInvocationOutput,
|
||||||
|
FieldDescriptions,
|
||||||
|
InputField,
|
||||||
|
Input,
|
||||||
|
InvocationContext,
|
||||||
|
OutputField,
|
||||||
|
UIType,
|
||||||
|
tags,
|
||||||
|
title,
|
||||||
|
)
|
||||||
|
|
||||||
CONTROLNET_DEFAULT_MODELS = [
|
|
||||||
###########################################
|
|
||||||
# lllyasviel sd v1.5, ControlNet v1.0 models
|
|
||||||
##############################################
|
|
||||||
"lllyasviel/sd-controlnet-canny",
|
|
||||||
"lllyasviel/sd-controlnet-depth",
|
|
||||||
"lllyasviel/sd-controlnet-hed",
|
|
||||||
"lllyasviel/sd-controlnet-seg",
|
|
||||||
"lllyasviel/sd-controlnet-openpose",
|
|
||||||
"lllyasviel/sd-controlnet-scribble",
|
|
||||||
"lllyasviel/sd-controlnet-normal",
|
|
||||||
"lllyasviel/sd-controlnet-mlsd",
|
|
||||||
#############################################
|
|
||||||
# lllyasviel sd v1.5, ControlNet v1.1 models
|
|
||||||
#############################################
|
|
||||||
"lllyasviel/control_v11p_sd15_canny",
|
|
||||||
"lllyasviel/control_v11p_sd15_openpose",
|
|
||||||
"lllyasviel/control_v11p_sd15_seg",
|
|
||||||
# "lllyasviel/control_v11p_sd15_depth", # broken
|
|
||||||
"lllyasviel/control_v11f1p_sd15_depth",
|
|
||||||
"lllyasviel/control_v11p_sd15_normalbae",
|
|
||||||
"lllyasviel/control_v11p_sd15_scribble",
|
|
||||||
"lllyasviel/control_v11p_sd15_mlsd",
|
|
||||||
"lllyasviel/control_v11p_sd15_softedge",
|
|
||||||
"lllyasviel/control_v11p_sd15s2_lineart_anime",
|
|
||||||
"lllyasviel/control_v11p_sd15_lineart",
|
|
||||||
"lllyasviel/control_v11p_sd15_inpaint",
|
|
||||||
# "lllyasviel/control_v11u_sd15_tile",
|
|
||||||
# problem (temporary?) with huffingface "lllyasviel/control_v11u_sd15_tile",
|
|
||||||
# so for now replace "lllyasviel/control_v11f1e_sd15_tile",
|
|
||||||
"lllyasviel/control_v11e_sd15_shuffle",
|
|
||||||
"lllyasviel/control_v11e_sd15_ip2p",
|
|
||||||
"lllyasviel/control_v11f1e_sd15_tile",
|
|
||||||
#################################################
|
|
||||||
# thibaud sd v2.1 models (ControlNet v1.0? or v1.1?
|
|
||||||
##################################################
|
|
||||||
"thibaud/controlnet-sd21-openpose-diffusers",
|
|
||||||
"thibaud/controlnet-sd21-canny-diffusers",
|
|
||||||
"thibaud/controlnet-sd21-depth-diffusers",
|
|
||||||
"thibaud/controlnet-sd21-scribble-diffusers",
|
|
||||||
"thibaud/controlnet-sd21-hed-diffusers",
|
|
||||||
"thibaud/controlnet-sd21-zoedepth-diffusers",
|
|
||||||
"thibaud/controlnet-sd21-color-diffusers",
|
|
||||||
"thibaud/controlnet-sd21-openposev2-diffusers",
|
|
||||||
"thibaud/controlnet-sd21-lineart-diffusers",
|
|
||||||
"thibaud/controlnet-sd21-normalbae-diffusers",
|
|
||||||
"thibaud/controlnet-sd21-ade20k-diffusers",
|
|
||||||
##############################################
|
|
||||||
# ControlNetMediaPipeface, ControlNet v1.1
|
|
||||||
##############################################
|
|
||||||
# ["CrucibleAI/ControlNetMediaPipeFace", "diffusion_sd15"], # SD 1.5
|
|
||||||
# diffusion_sd15 needs to be passed to from_pretrained() as subfolder arg
|
|
||||||
# hacked t2l to split to model & subfolder if format is "model,subfolder"
|
|
||||||
"CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15", # SD 1.5
|
|
||||||
"CrucibleAI/ControlNetMediaPipeFace", # SD 2.1?
|
|
||||||
]
|
|
||||||
|
|
||||||
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
|
CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]
|
||||||
CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]
|
|
||||||
CONTROLNET_RESIZE_VALUES = Literal[
|
CONTROLNET_RESIZE_VALUES = Literal[
|
||||||
tuple(
|
"just_resize",
|
||||||
[
|
"crop_resize",
|
||||||
"just_resize",
|
"fill_resize",
|
||||||
"crop_resize",
|
"just_resize_simple",
|
||||||
"fill_resize",
|
|
||||||
"just_resize_simple",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -110,9 +62,8 @@ class ControlNetModelField(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ControlField(BaseModel):
|
class ControlField(BaseModel):
|
||||||
image: ImageField = Field(default=None, description="The control image")
|
image: ImageField = Field(description="The control image")
|
||||||
control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
|
control_model: ControlNetModelField = Field(description="The ControlNet model to use")
|
||||||
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
|
|
||||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||||
begin_step_percent: float = Field(
|
begin_step_percent: float = Field(
|
||||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||||
@ -135,60 +86,39 @@ class ControlField(BaseModel):
|
|||||||
raise ValueError("Control weights must be within -1 to 2 range")
|
raise ValueError("Control weights must be within -1 to 2 range")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
class Config:
|
|
||||||
schema_extra = {
|
|
||||||
"required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"],
|
|
||||||
"ui": {
|
|
||||||
"type_hints": {
|
|
||||||
"control_weight": "float",
|
|
||||||
"control_model": "controlnet_model",
|
|
||||||
# "control_weight": "number",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ControlOutput(BaseInvocationOutput):
|
class ControlOutput(BaseInvocationOutput):
|
||||||
"""node output for ControlNet info"""
|
"""node output for ControlNet info"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["control_output"] = "control_output"
|
type: Literal["control_output"] = "control_output"
|
||||||
control: ControlField = Field(default=None, description="The control info")
|
|
||||||
# fmt: on
|
# Outputs
|
||||||
|
control: ControlField = OutputField(description=FieldDescriptions.control)
|
||||||
|
|
||||||
|
|
||||||
|
@title("ControlNet")
|
||||||
|
@tags("controlnet")
|
||||||
class ControlNetInvocation(BaseInvocation):
|
class ControlNetInvocation(BaseInvocation):
|
||||||
"""Collects ControlNet info to pass to other nodes"""
|
"""Collects ControlNet info to pass to other nodes"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["controlnet"] = "controlnet"
|
type: Literal["controlnet"] = "controlnet"
|
||||||
# Inputs
|
|
||||||
image: ImageField = Field(default=None, description="The control image")
|
|
||||||
control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny",
|
|
||||||
description="control model used")
|
|
||||||
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
|
|
||||||
begin_step_percent: float = Field(default=0, ge=-1, le=2,
|
|
||||||
description="When the ControlNet is first applied (% of total steps)")
|
|
||||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
|
||||||
description="When the ControlNet is last applied (% of total steps)")
|
|
||||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used")
|
|
||||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode used")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
image: ImageField = InputField(description="The control image")
|
||||||
"ui": {
|
control_model: ControlNetModelField = InputField(
|
||||||
"title": "ControlNet",
|
default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
|
||||||
"tags": ["controlnet", "latents"],
|
)
|
||||||
"type_hints": {
|
control_weight: Union[float, List[float]] = InputField(
|
||||||
"model": "model",
|
default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
|
||||||
"control": "control",
|
)
|
||||||
# "cfg_scale": "float",
|
begin_step_percent: float = InputField(
|
||||||
"cfg_scale": "number",
|
default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"
|
||||||
"control_weight": "float",
|
)
|
||||||
},
|
end_step_percent: float = InputField(
|
||||||
},
|
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||||
}
|
)
|
||||||
|
control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
|
||||||
|
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ControlOutput:
|
def invoke(self, context: InvocationContext) -> ControlOutput:
|
||||||
return ControlOutput(
|
return ControlOutput(
|
||||||
@ -204,19 +134,13 @@ class ControlNetInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
class ImageProcessorInvocation(BaseInvocation):
|
||||||
"""Base class for invocations that preprocess images for ControlNet"""
|
"""Base class for invocations that preprocess images for ControlNet"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["image_processor"] = "image_processor"
|
type: Literal["image_processor"] = "image_processor"
|
||||||
# Inputs
|
|
||||||
image: ImageField = Field(default=None, description="The image to process")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
image: ImageField = InputField(description="The image to process")
|
||||||
"ui": {"title": "Image Processor", "tags": ["image", "processor"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
# superclass just passes through image without processing
|
# superclass just passes through image without processing
|
||||||
@ -255,20 +179,20 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("Canny Processor")
|
||||||
|
@tags("controlnet", "canny")
|
||||||
|
class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Canny edge detection for ControlNet"""
|
"""Canny edge detection for ControlNet"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["canny_image_processor"] = "canny_image_processor"
|
type: Literal["canny_image_processor"] = "canny_image_processor"
|
||||||
# Input
|
|
||||||
low_threshold: int = Field(default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)")
|
|
||||||
high_threshold: int = Field(default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Input
|
||||||
schema_extra = {
|
low_threshold: int = InputField(
|
||||||
"ui": {"title": "Canny Processor", "tags": ["controlnet", "canny", "image", "processor"]},
|
default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
|
||||||
}
|
)
|
||||||
|
high_threshold: int = InputField(
|
||||||
|
default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)"
|
||||||
|
)
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
canny_processor = CannyDetector()
|
canny_processor = CannyDetector()
|
||||||
@ -276,23 +200,19 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("HED (softedge) Processor")
|
||||||
|
@tags("controlnet", "hed", "softedge")
|
||||||
|
class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies HED edge detection to image"""
|
"""Applies HED edge detection to image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["hed_image_processor"] = "hed_image_processor"
|
type: Literal["hed_image_processor"] = "hed_image_processor"
|
||||||
# Inputs
|
|
||||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
|
||||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
|
||||||
# safe not supported in controlnet_aux v0.0.3
|
|
||||||
# safe: bool = Field(default=False, description="whether to use safe mode")
|
|
||||||
scribble: bool = Field(default=False, description="Whether to use scribble mode")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
"ui": {"title": "Softedge(HED) Processor", "tags": ["controlnet", "softedge", "hed", "image", "processor"]},
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
}
|
# safe not supported in controlnet_aux v0.0.3
|
||||||
|
# safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
|
||||||
|
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
|
hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
|
||||||
@ -307,21 +227,17 @@ class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig)
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("Lineart Processor")
|
||||||
|
@tags("controlnet", "lineart")
|
||||||
|
class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies line art processing to image"""
|
"""Applies line art processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["lineart_image_processor"] = "lineart_image_processor"
|
type: Literal["lineart_image_processor"] = "lineart_image_processor"
|
||||||
# Inputs
|
|
||||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
|
||||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
|
||||||
coarse: bool = Field(default=False, description="Whether to use coarse mode")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
"ui": {"title": "Lineart Processor", "tags": ["controlnet", "lineart", "image", "processor"]},
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
}
|
coarse: bool = InputField(default=False, description="Whether to use coarse mode")
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
|
lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
@ -331,23 +247,16 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCon
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("Lineart Anime Processor")
|
||||||
|
@tags("controlnet", "lineart", "anime")
|
||||||
|
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies line art anime processing to image"""
|
"""Applies line art anime processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
|
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
|
||||||
# Inputs
|
|
||||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
|
||||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
"ui": {
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
"title": "Lineart Anime Processor",
|
|
||||||
"tags": ["controlnet", "lineart", "anime", "image", "processor"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
|
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
@ -359,21 +268,17 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocati
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("Openpose Processor")
|
||||||
|
@tags("controlnet", "openpose", "pose")
|
||||||
|
class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies Openpose processing to image"""
|
"""Applies Openpose processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["openpose_image_processor"] = "openpose_image_processor"
|
type: Literal["openpose_image_processor"] = "openpose_image_processor"
|
||||||
# Inputs
|
|
||||||
hand_and_face: bool = Field(default=False, description="Whether to use hands and face mode")
|
|
||||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
|
||||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
hand_and_face: bool = InputField(default=False, description="Whether to use hands and face mode")
|
||||||
"ui": {"title": "Openpose Processor", "tags": ["controlnet", "openpose", "image", "processor"]},
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
}
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
|
openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
@ -386,22 +291,18 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("Midas (Depth) Processor")
|
||||||
|
@tags("controlnet", "midas", "depth")
|
||||||
|
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies Midas depth processing to image"""
|
"""Applies Midas depth processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
|
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
|
||||||
# Inputs
|
|
||||||
a_mult: float = Field(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
|
|
||||||
bg_th: float = Field(default=0.1, ge=0, description="Midas parameter `bg_th`")
|
|
||||||
# depth_and_normal not supported in controlnet_aux v0.0.3
|
|
||||||
# depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
|
||||||
"ui": {"title": "Midas (Depth) Processor", "tags": ["controlnet", "midas", "depth", "image", "processor"]},
|
bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`")
|
||||||
}
|
# depth_and_normal not supported in controlnet_aux v0.0.3
|
||||||
|
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
|
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
@ -415,20 +316,16 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocation
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("Normal BAE Processor")
|
||||||
|
@tags("controlnet", "normal", "bae")
|
||||||
|
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies NormalBae processing to image"""
|
"""Applies NormalBae processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
|
type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
|
||||||
# Inputs
|
|
||||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
|
||||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
"ui": {"title": "Normal BAE Processor", "tags": ["controlnet", "normal", "bae", "image", "processor"]},
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
}
|
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
|
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
@ -438,22 +335,18 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationC
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("MLSD Processor")
|
||||||
|
@tags("controlnet", "mlsd")
|
||||||
|
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies MLSD processing to image"""
|
"""Applies MLSD processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
|
type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
|
||||||
# Inputs
|
|
||||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
|
||||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
|
||||||
thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_v`")
|
|
||||||
thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_d`")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
"ui": {"title": "MLSD Processor", "tags": ["controlnet", "mlsd", "image", "processor"]},
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
}
|
thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
|
||||||
|
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
|
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
|
||||||
@ -467,22 +360,18 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("PIDI Processor")
|
||||||
|
@tags("controlnet", "pidi")
|
||||||
|
class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies PIDI processing to image"""
|
"""Applies PIDI processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["pidi_image_processor"] = "pidi_image_processor"
|
type: Literal["pidi_image_processor"] = "pidi_image_processor"
|
||||||
# Inputs
|
|
||||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
|
||||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
|
||||||
safe: bool = Field(default=False, description="Whether to use safe mode")
|
|
||||||
scribble: bool = Field(default=False, description="Whether to use scribble mode")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
"ui": {"title": "PIDI Processor", "tags": ["controlnet", "pidi", "image", "processor"]},
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
}
|
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
|
||||||
|
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
|
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
@ -496,26 +385,19 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("Content Shuffle Processor")
|
||||||
|
@tags("controlnet", "contentshuffle")
|
||||||
|
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies content shuffle processing to image"""
|
"""Applies content shuffle processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
|
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
|
||||||
# Inputs
|
|
||||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
|
||||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
|
||||||
h: Optional[int] = Field(default=512, ge=0, description="Content shuffle `h` parameter")
|
|
||||||
w: Optional[int] = Field(default=512, ge=0, description="Content shuffle `w` parameter")
|
|
||||||
f: Optional[int] = Field(default=256, ge=0, description="Content shuffle `f` parameter")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
"ui": {
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
"title": "Content Shuffle Processor",
|
h: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
|
||||||
"tags": ["controlnet", "contentshuffle", "image", "processor"],
|
w: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
|
||||||
},
|
f: Optional[int] = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
|
||||||
}
|
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
content_shuffle_processor = ContentShuffleDetector()
|
content_shuffle_processor = ContentShuffleDetector()
|
||||||
@ -531,17 +413,12 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvoca
|
|||||||
|
|
||||||
|
|
||||||
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
|
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
|
||||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("Zoe (Depth) Processor")
|
||||||
|
@tags("controlnet", "zoe", "depth")
|
||||||
|
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies Zoe depth processing to image"""
|
"""Applies Zoe depth processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
|
type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Zoe (Depth) Processor", "tags": ["controlnet", "zoe", "depth", "image", "processor"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
|
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
@ -549,20 +426,16 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("Mediapipe Face Processor")
|
||||||
|
@tags("controlnet", "mediapipe", "face")
|
||||||
|
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies mediapipe face processing to image"""
|
"""Applies mediapipe face processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
|
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
|
||||||
# Inputs
|
|
||||||
max_faces: int = Field(default=1, ge=1, description="Maximum number of faces to detect")
|
|
||||||
min_confidence: float = Field(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
|
||||||
"ui": {"title": "Mediapipe Processor", "tags": ["controlnet", "mediapipe", "image", "processor"]},
|
min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
|
||||||
}
|
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
# MediaPipeFaceDetector throws an error if image has alpha channel
|
# MediaPipeFaceDetector throws an error if image has alpha channel
|
||||||
@ -574,23 +447,19 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("Leres (Depth) Processor")
|
||||||
|
@tags("controlnet", "leres", "depth")
|
||||||
|
class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies leres processing to image"""
|
"""Applies leres processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["leres_image_processor"] = "leres_image_processor"
|
type: Literal["leres_image_processor"] = "leres_image_processor"
|
||||||
# Inputs
|
|
||||||
thr_a: float = Field(default=0, description="Leres parameter `thr_a`")
|
|
||||||
thr_b: float = Field(default=0, description="Leres parameter `thr_b`")
|
|
||||||
boost: bool = Field(default=False, description="Whether to use boost mode")
|
|
||||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
|
||||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
thr_a: float = InputField(default=0, description="Leres parameter `thr_a`")
|
||||||
"ui": {"title": "Leres (Depth) Processor", "tags": ["controlnet", "leres", "depth", "image", "processor"]},
|
thr_b: float = InputField(default=0, description="Leres parameter `thr_b`")
|
||||||
}
|
boost: bool = InputField(default=False, description="Whether to use boost mode")
|
||||||
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
|
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
@ -605,21 +474,16 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("Tile Resample Processor")
|
||||||
# fmt: off
|
@tags("controlnet", "tile")
|
||||||
type: Literal["tile_image_processor"] = "tile_image_processor"
|
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||||
# Inputs
|
"""Tile resampler processor"""
|
||||||
#res: int = Field(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
|
|
||||||
down_sampling_rate: float = Field(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
type: Literal["tile_image_processor"] = "tile_image_processor"
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
# Inputs
|
||||||
"title": "Tile Resample Processor",
|
# res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
|
||||||
"tags": ["controlnet", "tile", "resample", "image", "processor"],
|
down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# tile_resample copied from sd-webui-controlnet/scripts/processor.py
|
# tile_resample copied from sd-webui-controlnet/scripts/processor.py
|
||||||
def tile_resample(
|
def tile_resample(
|
||||||
@ -648,20 +512,12 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
@title("Segment Anything Processor")
|
||||||
|
@tags("controlnet", "segmentanything")
|
||||||
|
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies segment anything processing to image"""
|
"""Applies segment anything processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["segment_anything_processor"] = "segment_anything_processor"
|
type: Literal["segment_anything_processor"] = "segment_anything_processor"
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "Segment Anything Processor",
|
|
||||||
"tags": ["controlnet", "segment", "anything", "sam", "image", "processor"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
|
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
|
||||||
|
@ -5,40 +5,22 @@ from typing import Literal
|
|||||||
import cv2 as cv
|
import cv2 as cv
|
||||||
import numpy
|
import numpy
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from pydantic import BaseModel, Field
|
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
|
||||||
from .image import ImageOutput
|
|
||||||
|
|
||||||
|
|
||||||
class CvInvocationConfig(BaseModel):
|
@title("OpenCV Inpaint")
|
||||||
"""Helper class to provide all OpenCV invocations with additional config"""
|
@tags("opencv", "inpaint")
|
||||||
|
class CvInpaintInvocation(BaseInvocation):
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["cv", "image"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
|
||||||
"""Simple inpaint using opencv."""
|
"""Simple inpaint using opencv."""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["cv_inpaint"] = "cv_inpaint"
|
type: Literal["cv_inpaint"] = "cv_inpaint"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="The image to inpaint")
|
image: ImageField = InputField(description="The image to inpaint")
|
||||||
mask: ImageField = Field(default=None, description="The mask to use when inpainting")
|
mask: ImageField = InputField(description="The mask to use when inpainting")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "OpenCV Inpaint", "tags": ["opencv", "inpaint"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
@ -1,60 +1,31 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy
|
import numpy
|
||||||
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from invokeai.app.invocations.metadata import CoreMetadata
|
from invokeai.app.invocations.metadata import CoreMetadata
|
||||||
|
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||||
|
|
||||||
from ..models.image import ImageCategory, ImageField, ImageOutput, MaskOutput, PILInvocationConfig, ResourceOrigin
|
from ..models.image import ImageCategory, ResourceOrigin
|
||||||
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
|
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, tags, title
|
||||||
|
|
||||||
|
|
||||||
class LoadImageInvocation(BaseInvocation):
|
|
||||||
"""Load an image and provide it as output."""
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["load_image"] = "load_image"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: Optional[ImageField] = Field(
|
|
||||||
default=None, description="The image to load"
|
|
||||||
)
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Load Image", "tags": ["image", "load"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
|
||||||
|
|
||||||
return ImageOutput(
|
|
||||||
image=ImageField(image_name=self.image.image_name),
|
|
||||||
width=image.width,
|
|
||||||
height=image.height,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
@title("Show Image")
|
||||||
|
@tags("image")
|
||||||
class ShowImageInvocation(BaseInvocation):
|
class ShowImageInvocation(BaseInvocation):
|
||||||
"""Displays a provided image, and passes it forward in the pipeline."""
|
"""Displays a provided image, and passes it forward in the pipeline."""
|
||||||
|
|
||||||
|
# Metadata
|
||||||
type: Literal["show_image"] = "show_image"
|
type: Literal["show_image"] = "show_image"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to show")
|
image: ImageField = InputField(description="The image to show")
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Show Image", "tags": ["image", "show"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -70,24 +41,20 @@ class ShowImageInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Crop Image")
|
||||||
|
@tags("image", "crop")
|
||||||
|
class ImageCropInvocation(BaseInvocation):
|
||||||
"""Crops an image to a specified box. The box can be outside of the image."""
|
"""Crops an image to a specified box. The box can be outside of the image."""
|
||||||
|
|
||||||
# fmt: off
|
# Metadata
|
||||||
type: Literal["img_crop"] = "img_crop"
|
type: Literal["img_crop"] = "img_crop"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to crop")
|
image: ImageField = InputField(description="The image to crop")
|
||||||
x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
|
x: int = InputField(default=0, description="The left x coordinate of the crop rectangle")
|
||||||
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
|
y: int = InputField(default=0, description="The top y coordinate of the crop rectangle")
|
||||||
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
|
width: int = InputField(default=512, gt=0, description="The width of the crop rectangle")
|
||||||
height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
|
height: int = InputField(default=512, gt=0, description="The height of the crop rectangle")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Crop Image", "tags": ["image", "crop"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -111,24 +78,23 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Paste Image")
|
||||||
|
@tags("image", "paste")
|
||||||
|
class ImagePasteInvocation(BaseInvocation):
|
||||||
"""Pastes an image into another image."""
|
"""Pastes an image into another image."""
|
||||||
|
|
||||||
# fmt: off
|
# Metadata
|
||||||
type: Literal["img_paste"] = "img_paste"
|
type: Literal["img_paste"] = "img_paste"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
base_image: Optional[ImageField] = Field(default=None, description="The base image")
|
base_image: ImageField = InputField(description="The base image")
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to paste")
|
image: ImageField = InputField(description="The image to paste")
|
||||||
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
|
mask: Optional[ImageField] = InputField(
|
||||||
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
|
default=None,
|
||||||
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
|
description="The mask to use when pasting",
|
||||||
# fmt: on
|
)
|
||||||
|
x: int = InputField(default=0, description="The left x coordinate at which to paste the image")
|
||||||
class Config(InvocationConfig):
|
y: int = InputField(default=0, description="The top y coordinate at which to paste the image")
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Paste Image", "tags": ["image", "paste"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
base_image = context.services.images.get_pil_image(self.base_image.image_name)
|
base_image = context.services.images.get_pil_image(self.base_image.image_name)
|
||||||
@ -164,23 +130,19 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Mask from Alpha")
|
||||||
|
@tags("image", "mask")
|
||||||
|
class MaskFromAlphaInvocation(BaseInvocation):
|
||||||
"""Extracts the alpha channel of an image as a mask."""
|
"""Extracts the alpha channel of an image as a mask."""
|
||||||
|
|
||||||
# fmt: off
|
# Metadata
|
||||||
type: Literal["tomask"] = "tomask"
|
type: Literal["tomask"] = "tomask"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to create the mask from")
|
image: ImageField = InputField(description="The image to create the mask from")
|
||||||
invert: bool = Field(default=False, description="Whether or not to invert the mask")
|
invert: bool = InputField(default=False, description="Whether or not to invert the mask")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Mask From Alpha", "tags": ["image", "mask", "alpha"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
image_mask = image.split()[-1]
|
image_mask = image.split()[-1]
|
||||||
@ -196,28 +158,24 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
)
|
)
|
||||||
|
|
||||||
return MaskOutput(
|
return ImageOutput(
|
||||||
mask=ImageField(image_name=image_dto.image_name),
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Multiply Images")
|
||||||
|
@tags("image", "multiply")
|
||||||
|
class ImageMultiplyInvocation(BaseInvocation):
|
||||||
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
||||||
|
|
||||||
# fmt: off
|
# Metadata
|
||||||
type: Literal["img_mul"] = "img_mul"
|
type: Literal["img_mul"] = "img_mul"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image1: Optional[ImageField] = Field(default=None, description="The first image to multiply")
|
image1: ImageField = InputField(description="The first image to multiply")
|
||||||
image2: Optional[ImageField] = Field(default=None, description="The second image to multiply")
|
image2: ImageField = InputField(description="The second image to multiply")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Multiply Images", "tags": ["image", "multiply"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image1 = context.services.images.get_pil_image(self.image1.image_name)
|
image1 = context.services.images.get_pil_image(self.image1.image_name)
|
||||||
@ -244,21 +202,17 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
||||||
|
|
||||||
|
|
||||||
class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Extract Image Channel")
|
||||||
|
@tags("image", "channel")
|
||||||
|
class ImageChannelInvocation(BaseInvocation):
|
||||||
"""Gets a channel from an image."""
|
"""Gets a channel from an image."""
|
||||||
|
|
||||||
# fmt: off
|
# Metadata
|
||||||
type: Literal["img_chan"] = "img_chan"
|
type: Literal["img_chan"] = "img_chan"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to get the channel from")
|
image: ImageField = InputField(description="The image to get the channel from")
|
||||||
channel: IMAGE_CHANNELS = Field(default="A", description="The channel to get")
|
channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Image Channel", "tags": ["image", "channel"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -284,21 +238,17 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
|
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
|
||||||
|
|
||||||
|
|
||||||
class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Convert Image Mode")
|
||||||
|
@tags("image", "convert")
|
||||||
|
class ImageConvertInvocation(BaseInvocation):
|
||||||
"""Converts an image to a different mode."""
|
"""Converts an image to a different mode."""
|
||||||
|
|
||||||
# fmt: off
|
# Metadata
|
||||||
type: Literal["img_conv"] = "img_conv"
|
type: Literal["img_conv"] = "img_conv"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to convert")
|
image: ImageField = InputField(description="The image to convert")
|
||||||
mode: IMAGE_MODES = Field(default="L", description="The mode to convert to")
|
mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Convert Image", "tags": ["image", "convert"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -321,22 +271,19 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Blur Image")
|
||||||
|
@tags("image", "blur")
|
||||||
|
class ImageBlurInvocation(BaseInvocation):
|
||||||
"""Blurs an image"""
|
"""Blurs an image"""
|
||||||
|
|
||||||
# fmt: off
|
# Metadata
|
||||||
type: Literal["img_blur"] = "img_blur"
|
type: Literal["img_blur"] = "img_blur"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to blur")
|
image: ImageField = InputField(description="The image to blur")
|
||||||
radius: float = Field(default=8.0, ge=0, description="The blur radius")
|
radius: float = InputField(default=8.0, ge=0, description="The blur radius")
|
||||||
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
|
# Metadata
|
||||||
# fmt: on
|
blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur")
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Blur Image", "tags": ["image", "blur"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -382,23 +329,19 @@ PIL_RESAMPLING_MAP = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Resize Image")
|
||||||
|
@tags("image", "resize")
|
||||||
|
class ImageResizeInvocation(BaseInvocation):
|
||||||
"""Resizes an image to specific dimensions"""
|
"""Resizes an image to specific dimensions"""
|
||||||
|
|
||||||
# fmt: off
|
# Metadata
|
||||||
type: Literal["img_resize"] = "img_resize"
|
type: Literal["img_resize"] = "img_resize"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to resize")
|
image: ImageField = InputField(description="The image to resize")
|
||||||
width: Union[int, None] = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
|
width: int = InputField(default=512, ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||||
height: Union[int, None] = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
|
height: int = InputField(default=512, ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||||
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Resize Image", "tags": ["image", "resize"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -426,22 +369,22 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Scale Image")
|
||||||
|
@tags("image", "scale")
|
||||||
|
class ImageScaleInvocation(BaseInvocation):
|
||||||
"""Scales an image by a factor"""
|
"""Scales an image by a factor"""
|
||||||
|
|
||||||
# fmt: off
|
# Metadata
|
||||||
type: Literal["img_scale"] = "img_scale"
|
type: Literal["img_scale"] = "img_scale"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to scale")
|
image: ImageField = InputField(description="The image to scale")
|
||||||
scale_factor: Optional[float] = Field(default=2.0, gt=0, description="The factor by which to scale the image")
|
scale_factor: float = InputField(
|
||||||
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
default=2.0,
|
||||||
# fmt: on
|
gt=0,
|
||||||
|
description="The factor by which to scale the image",
|
||||||
class Config(InvocationConfig):
|
)
|
||||||
schema_extra = {
|
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
||||||
"ui": {"title": "Scale Image", "tags": ["image", "scale"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -471,22 +414,18 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Lerp Image")
|
||||||
|
@tags("image", "lerp")
|
||||||
|
class ImageLerpInvocation(BaseInvocation):
|
||||||
"""Linear interpolation of all pixels of an image"""
|
"""Linear interpolation of all pixels of an image"""
|
||||||
|
|
||||||
# fmt: off
|
# Metadata
|
||||||
type: Literal["img_lerp"] = "img_lerp"
|
type: Literal["img_lerp"] = "img_lerp"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to lerp")
|
image: ImageField = InputField(description="The image to lerp")
|
||||||
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
|
min: int = InputField(default=0, ge=0, le=255, description="The minimum output value")
|
||||||
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
|
max: int = InputField(default=255, ge=0, le=255, description="The maximum output value")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Image Linear Interpolation", "tags": ["image", "linear", "interpolation", "lerp"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -512,25 +451,18 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Inverse Lerp Image")
|
||||||
|
@tags("image", "ilerp")
|
||||||
|
class ImageInverseLerpInvocation(BaseInvocation):
|
||||||
"""Inverse linear interpolation of all pixels of an image"""
|
"""Inverse linear interpolation of all pixels of an image"""
|
||||||
|
|
||||||
# fmt: off
|
# Metadata
|
||||||
type: Literal["img_ilerp"] = "img_ilerp"
|
type: Literal["img_ilerp"] = "img_ilerp"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to lerp")
|
image: ImageField = InputField(description="The image to lerp")
|
||||||
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
|
min: int = InputField(default=0, ge=0, le=255, description="The minimum input value")
|
||||||
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
|
max: int = InputField(default=255, ge=0, le=255, description="The maximum input value")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "Image Inverse Linear Interpolation",
|
|
||||||
"tags": ["image", "linear", "interpolation", "inverse"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -556,21 +488,19 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Blur NSFW Image")
|
||||||
|
@tags("image", "nsfw")
|
||||||
|
class ImageNSFWBlurInvocation(BaseInvocation):
|
||||||
"""Add blur to NSFW-flagged images"""
|
"""Add blur to NSFW-flagged images"""
|
||||||
|
|
||||||
# fmt: off
|
# Metadata
|
||||||
type: Literal["img_nsfw"] = "img_nsfw"
|
type: Literal["img_nsfw"] = "img_nsfw"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to check")
|
image: ImageField = InputField(description="The image to check")
|
||||||
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
|
metadata: Optional[CoreMetadata] = InputField(
|
||||||
# fmt: on
|
default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
|
||||||
|
)
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Blur NSFW Images", "tags": ["image", "nsfw", "checker"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -607,22 +537,20 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
return caution.resize((caution.width // 2, caution.height // 2))
|
return caution.resize((caution.width // 2, caution.height // 2))
|
||||||
|
|
||||||
|
|
||||||
class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Add Invisible Watermark")
|
||||||
|
@tags("image", "watermark")
|
||||||
|
class ImageWatermarkInvocation(BaseInvocation):
|
||||||
"""Add an invisible watermark to an image"""
|
"""Add an invisible watermark to an image"""
|
||||||
|
|
||||||
# fmt: off
|
# Metadata
|
||||||
type: Literal["img_watermark"] = "img_watermark"
|
type: Literal["img_watermark"] = "img_watermark"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to check")
|
image: ImageField = InputField(description="The image to check")
|
||||||
text: str = Field(default='InvokeAI', description="Watermark text")
|
text: str = InputField(default="InvokeAI", description="Watermark text")
|
||||||
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
|
metadata: Optional[CoreMetadata] = InputField(
|
||||||
# fmt: on
|
default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
|
||||||
|
)
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Add Invisible Watermark", "tags": ["image", "watermark", "invisible"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -644,21 +572,23 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MaskEdgeInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Mask Edge")
|
||||||
|
@tags("image", "mask", "inpaint")
|
||||||
|
class MaskEdgeInvocation(BaseInvocation):
|
||||||
"""Applies an edge mask to an image"""
|
"""Applies an edge mask to an image"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["mask_edge"] = "mask_edge"
|
type: Literal["mask_edge"] = "mask_edge"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to apply the mask to")
|
image: ImageField = InputField(description="The image to apply the mask to")
|
||||||
edge_size: int = Field(description="The size of the edge")
|
edge_size: int = InputField(description="The size of the edge")
|
||||||
edge_blur: int = Field(description="The amount of blur on the edge")
|
edge_blur: int = InputField(description="The amount of blur on the edge")
|
||||||
low_threshold: int = Field(description="First threshold for the hysteresis procedure in Canny edge detection")
|
low_threshold: int = InputField(description="First threshold for the hysteresis procedure in Canny edge detection")
|
||||||
high_threshold: int = Field(description="Second threshold for the hysteresis procedure in Canny edge detection")
|
high_threshold: int = InputField(
|
||||||
# fmt: on
|
description="Second threshold for the hysteresis procedure in Canny edge detection"
|
||||||
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
mask = context.services.images.get_pil_image(self.image.image_name)
|
mask = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
npimg = numpy.asarray(mask, dtype=numpy.uint8)
|
npimg = numpy.asarray(mask, dtype=numpy.uint8)
|
||||||
@ -683,28 +613,23 @@ class MaskEdgeInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
)
|
)
|
||||||
|
|
||||||
return MaskOutput(
|
return ImageOutput(
|
||||||
mask=ImageField(image_name=image_dto.image_name),
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MaskCombineInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Combine Mask")
|
||||||
|
@tags("image", "mask", "multiply")
|
||||||
|
class MaskCombineInvocation(BaseInvocation):
|
||||||
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
|
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["mask_combine"] = "mask_combine"
|
type: Literal["mask_combine"] = "mask_combine"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
mask1: ImageField = Field(default=None, description="The first mask to combine")
|
mask1: ImageField = InputField(description="The first mask to combine")
|
||||||
mask2: ImageField = Field(default=None, description="The second image to combine")
|
mask2: ImageField = InputField(description="The second image to combine")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Mask Combine", "tags": ["mask", "combine"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
mask1 = context.services.images.get_pil_image(self.mask1.image_name).convert("L")
|
mask1 = context.services.images.get_pil_image(self.mask1.image_name).convert("L")
|
||||||
@ -728,7 +653,9 @@ class MaskCombineInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
|
@title("Color Correct")
|
||||||
|
@tags("image", "color")
|
||||||
|
class ColorCorrectInvocation(BaseInvocation):
|
||||||
"""
|
"""
|
||||||
Shifts the colors of a target image to match the reference image, optionally
|
Shifts the colors of a target image to match the reference image, optionally
|
||||||
using a mask to only color-correct certain regions of the target image.
|
using a mask to only color-correct certain regions of the target image.
|
||||||
@ -736,10 +663,11 @@ class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
type: Literal["color_correct"] = "color_correct"
|
type: Literal["color_correct"] = "color_correct"
|
||||||
|
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to color-correct")
|
# Inputs
|
||||||
reference: Optional[ImageField] = Field(default=None, description="Reference image for color-correction")
|
image: ImageField = InputField(description="The image to color-correct")
|
||||||
mask: Optional[ImageField] = Field(default=None, description="Mask to use when applying color-correction")
|
reference: ImageField = InputField(description="Reference image for color-correction")
|
||||||
mask_blur_radius: float = Field(default=8, description="Mask blur radius")
|
mask: Optional[ImageField] = InputField(default=None, description="Mask to use when applying color-correction")
|
||||||
|
mask_blur_radius: float = InputField(default=8, description="Mask blur radius")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
pil_init_mask = None
|
pil_init_mask = None
|
||||||
@ -833,16 +761,16 @@ class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Image Hue Adjustment")
|
||||||
|
@tags("image", "hue", "hsl")
|
||||||
class ImageHueAdjustmentInvocation(BaseInvocation):
|
class ImageHueAdjustmentInvocation(BaseInvocation):
|
||||||
"""Adjusts the Hue of an image."""
|
"""Adjusts the Hue of an image."""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["img_hue_adjust"] = "img_hue_adjust"
|
type: Literal["img_hue_adjust"] = "img_hue_adjust"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="The image to adjust")
|
image: ImageField = InputField(description="The image to adjust")
|
||||||
hue: int = Field(default=0, description="The degrees by which to rotate the hue, 0-360")
|
hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -877,16 +805,18 @@ class ImageHueAdjustmentInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Image Luminosity Adjustment")
|
||||||
|
@tags("image", "luminosity", "hsl")
|
||||||
class ImageLuminosityAdjustmentInvocation(BaseInvocation):
|
class ImageLuminosityAdjustmentInvocation(BaseInvocation):
|
||||||
"""Adjusts the Luminosity (Value) of an image."""
|
"""Adjusts the Luminosity (Value) of an image."""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["img_luminosity_adjust"] = "img_luminosity_adjust"
|
type: Literal["img_luminosity_adjust"] = "img_luminosity_adjust"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="The image to adjust")
|
image: ImageField = InputField(description="The image to adjust")
|
||||||
luminosity: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)")
|
luminosity: float = InputField(
|
||||||
# fmt: on
|
default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)"
|
||||||
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -925,16 +855,16 @@ class ImageLuminosityAdjustmentInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Image Saturation Adjustment")
|
||||||
|
@tags("image", "saturation", "hsl")
|
||||||
class ImageSaturationAdjustmentInvocation(BaseInvocation):
|
class ImageSaturationAdjustmentInvocation(BaseInvocation):
|
||||||
"""Adjusts the Saturation of an image."""
|
"""Adjusts the Saturation of an image."""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["img_saturation_adjust"] = "img_saturation_adjust"
|
type: Literal["img_saturation_adjust"] = "img_saturation_adjust"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="The image to adjust")
|
image: ImageField = InputField(description="The image to adjust")
|
||||||
saturation: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation")
|
saturation: float = InputField(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
@ -5,18 +5,13 @@ from typing import Literal, Optional, get_args
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import math
|
import math
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from pydantic import Field
|
from invokeai.app.invocations.primitives import ImageField, ImageOutput, ColorField
|
||||||
|
|
||||||
from invokeai.app.invocations.image import ImageOutput
|
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||||
|
|
||||||
from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
|
from ..models.image import ImageCategory, ResourceOrigin
|
||||||
from .baseinvocation import (
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, title, tags
|
||||||
BaseInvocation,
|
|
||||||
InvocationConfig,
|
|
||||||
InvocationContext,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def infill_methods() -> list[str]:
|
def infill_methods() -> list[str]:
|
||||||
@ -114,21 +109,20 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
|
|||||||
return si
|
return si
|
||||||
|
|
||||||
|
|
||||||
|
@title("Solid Color Infill")
|
||||||
|
@tags("image", "inpaint")
|
||||||
class InfillColorInvocation(BaseInvocation):
|
class InfillColorInvocation(BaseInvocation):
|
||||||
"""Infills transparent areas of an image with a solid color"""
|
"""Infills transparent areas of an image with a solid color"""
|
||||||
|
|
||||||
type: Literal["infill_rgba"] = "infill_rgba"
|
type: Literal["infill_rgba"] = "infill_rgba"
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
|
||||||
color: ColorField = Field(
|
# Inputs
|
||||||
|
image: ImageField = InputField(description="The image to infill")
|
||||||
|
color: ColorField = InputField(
|
||||||
default=ColorField(r=127, g=127, b=127, a=255),
|
default=ColorField(r=127, g=127, b=127, a=255),
|
||||||
description="The color to use to infill",
|
description="The color to use to infill",
|
||||||
)
|
)
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Color Infill", "tags": ["image", "inpaint", "color", "infill"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
@ -153,25 +147,23 @@ class InfillColorInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Tile Infill")
|
||||||
|
@tags("image", "inpaint")
|
||||||
class InfillTileInvocation(BaseInvocation):
|
class InfillTileInvocation(BaseInvocation):
|
||||||
"""Infills transparent areas of an image with tiles of the image"""
|
"""Infills transparent areas of an image with tiles of the image"""
|
||||||
|
|
||||||
type: Literal["infill_tile"] = "infill_tile"
|
type: Literal["infill_tile"] = "infill_tile"
|
||||||
|
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
# Input
|
||||||
tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
|
image: ImageField = InputField(description="The image to infill")
|
||||||
seed: int = Field(
|
tile_size: int = InputField(default=32, ge=1, description="The tile size (px)")
|
||||||
|
seed: int = InputField(
|
||||||
ge=0,
|
ge=0,
|
||||||
le=SEED_MAX,
|
le=SEED_MAX,
|
||||||
description="The seed to use for tile generation (omit for random)",
|
description="The seed to use for tile generation (omit for random)",
|
||||||
default_factory=get_random_seed,
|
default_factory=get_random_seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Tile Infill", "tags": ["image", "inpaint", "tile", "infill"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
@ -194,17 +186,15 @@ class InfillTileInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("PatchMatch Infill")
|
||||||
|
@tags("image", "inpaint")
|
||||||
class InfillPatchMatchInvocation(BaseInvocation):
|
class InfillPatchMatchInvocation(BaseInvocation):
|
||||||
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
||||||
|
|
||||||
type: Literal["infill_patchmatch"] = "infill_patchmatch"
|
type: Literal["infill_patchmatch"] = "infill_patchmatch"
|
||||||
|
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
# Inputs
|
||||||
|
image: ImageField = InputField(description="The image to infill")
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Patch Match Infill", "tags": ["image", "inpaint", "patchmatch", "infill"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
@ -13,16 +13,25 @@ from diffusers.models.attention_processor import (
|
|||||||
LoRAXFormersAttnProcessor,
|
LoRAXFormersAttnProcessor,
|
||||||
XFormersAttnProcessor,
|
XFormersAttnProcessor,
|
||||||
)
|
)
|
||||||
from diffusers.schedulers import DPMSolverSDEScheduler, SchedulerMixin as Scheduler
|
from diffusers.schedulers import DPMSolverSDEScheduler
|
||||||
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
from pydantic import BaseModel, Field, validator
|
from pydantic import BaseModel, Field, validator
|
||||||
from torchvision.transforms.functional import resize as tv_resize
|
from torchvision.transforms.functional import resize as tv_resize
|
||||||
|
|
||||||
from invokeai.app.invocations.metadata import CoreMetadata
|
from invokeai.app.invocations.metadata import CoreMetadata
|
||||||
|
from invokeai.app.invocations.primitives import (
|
||||||
|
ImageField,
|
||||||
|
ImageOutput,
|
||||||
|
LatentsField,
|
||||||
|
LatentsOutput,
|
||||||
|
build_latents_output,
|
||||||
|
)
|
||||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
||||||
|
|
||||||
from ...backend.model_management import BaseModelType, ModelPatcher
|
from ...backend.model_management import BaseModelType, ModelPatcher
|
||||||
|
from ...backend.model_management.lora import ModelPatcher
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||||
ConditioningData,
|
ConditioningData,
|
||||||
@ -32,48 +41,27 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
|
|||||||
)
|
)
|
||||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
from ...backend.util.devices import choose_precision, choose_torch_device, torch_dtype
|
from ...backend.util.devices import choose_precision, choose_torch_device
|
||||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
from ..models.image import ImageCategory, ResourceOrigin
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
from .baseinvocation import (
|
||||||
|
BaseInvocation,
|
||||||
|
BaseInvocationOutput,
|
||||||
|
FieldDescriptions,
|
||||||
|
Input,
|
||||||
|
InputField,
|
||||||
|
InvocationContext,
|
||||||
|
OutputField,
|
||||||
|
UIType,
|
||||||
|
tags,
|
||||||
|
title,
|
||||||
|
)
|
||||||
from .compel import ConditioningField
|
from .compel import ConditioningField
|
||||||
from .controlnet_image_processors import ControlField
|
from .controlnet_image_processors import ControlField
|
||||||
from .image import ImageOutput
|
|
||||||
from .model import ModelInfo, UNetField, VaeField
|
from .model import ModelInfo, UNetField, VaeField
|
||||||
|
|
||||||
DEFAULT_PRECISION = choose_precision(choose_torch_device())
|
DEFAULT_PRECISION = choose_precision(choose_torch_device())
|
||||||
|
|
||||||
|
|
||||||
class LatentsField(BaseModel):
|
|
||||||
"""A latents field used for passing latents between invocations"""
|
|
||||||
|
|
||||||
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
|
|
||||||
seed: Optional[int] = Field(description="Seed used to generate this latents")
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
schema_extra = {"required": ["latents_name"]}
|
|
||||||
|
|
||||||
|
|
||||||
class LatentsOutput(BaseInvocationOutput):
|
|
||||||
"""Base class for invocations that output latents"""
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["latents_output"] = "latents_output"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
latents: LatentsField = Field(default=None, description="The output latents")
|
|
||||||
width: int = Field(description="The width of the latents in pixels")
|
|
||||||
height: int = Field(description="The height of the latents in pixels")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int]):
|
|
||||||
return LatentsOutput(
|
|
||||||
latents=LatentsField(latents_name=latents_name, seed=seed),
|
|
||||||
width=latents.size()[3] * 8,
|
|
||||||
height=latents.size()[2] * 8,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
|
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
|
||||||
|
|
||||||
|
|
||||||
@ -111,30 +99,36 @@ def get_scheduler(
|
|||||||
return scheduler
|
return scheduler
|
||||||
|
|
||||||
|
|
||||||
|
@title("Denoise Latents")
|
||||||
|
@tags("latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l")
|
||||||
class DenoiseLatentsInvocation(BaseInvocation):
|
class DenoiseLatentsInvocation(BaseInvocation):
|
||||||
"""Denoises noisy latents to decodable images"""
|
"""Denoises noisy latents to decodable images"""
|
||||||
|
|
||||||
type: Literal["denoise_latents"] = "denoise_latents"
|
type: Literal["denoise_latents"] = "denoise_latents"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
positive_conditioning: ConditioningField = InputField(
|
||||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
|
||||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
|
||||||
cfg_scale: Union[float, List[float]] = Field(
|
|
||||||
default=7.5,
|
|
||||||
ge=1,
|
|
||||||
description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt",
|
|
||||||
)
|
)
|
||||||
denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
|
negative_conditioning: ConditioningField = InputField(
|
||||||
denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
|
description=FieldDescriptions.negative_cond, input=Input.Connection
|
||||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use")
|
)
|
||||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection)
|
||||||
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
cfg_scale: Union[float, List[float]] = InputField(
|
||||||
mask: Optional[ImageField] = Field(
|
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, ui_type=UIType.Float
|
||||||
None,
|
)
|
||||||
description="Mask",
|
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
|
||||||
|
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||||
|
scheduler: SAMPLER_NAME_VALUES = InputField(default="euler", description=FieldDescriptions.scheduler)
|
||||||
|
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection)
|
||||||
|
control: Union[ControlField, list[ControlField]] = InputField(
|
||||||
|
default=None, description=FieldDescriptions.control, input=Input.Connection
|
||||||
|
)
|
||||||
|
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
||||||
|
mask: Optional[ImageField] = InputField(
|
||||||
|
default=None,
|
||||||
|
description=FieldDescriptions.mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
@validator("cfg_scale")
|
@validator("cfg_scale")
|
||||||
@ -149,20 +143,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
raise ValueError("cfg_scale must be greater than 1")
|
raise ValueError("cfg_scale must be greater than 1")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "Denoise Latents",
|
|
||||||
"tags": ["denoise", "latents"],
|
|
||||||
"type_hints": {
|
|
||||||
"model": "model",
|
|
||||||
"control": "control",
|
|
||||||
"cfg_scale": "number",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||||
def dispatch_progress(
|
def dispatch_progress(
|
||||||
self,
|
self,
|
||||||
@ -474,29 +454,29 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
|
return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
|
||||||
|
|
||||||
|
|
||||||
# Latent to image
|
@title("Latents to Image")
|
||||||
|
@tags("latents", "image", "vae")
|
||||||
class LatentsToImageInvocation(BaseInvocation):
|
class LatentsToImageInvocation(BaseInvocation):
|
||||||
"""Generates an image from latents."""
|
"""Generates an image from latents."""
|
||||||
|
|
||||||
type: Literal["l2i"] = "l2i"
|
type: Literal["l2i"] = "l2i"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
|
latents: LatentsField = InputField(
|
||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
description=FieldDescriptions.latents,
|
||||||
tiled: bool = Field(default=False, description="Decode latents by overlaping tiles (less memory consumption)")
|
input=Input.Connection,
|
||||||
fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
|
)
|
||||||
metadata: Optional[CoreMetadata] = Field(
|
vae: VaeField = InputField(
|
||||||
default=None, description="Optional core metadata to be written to the image"
|
description=FieldDescriptions.vae,
|
||||||
|
input=Input.Connection,
|
||||||
|
)
|
||||||
|
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
||||||
|
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
|
||||||
|
metadata: CoreMetadata = InputField(
|
||||||
|
default=None,
|
||||||
|
description=FieldDescriptions.core_metadata,
|
||||||
|
ui_hidden=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "Latents To Image",
|
|
||||||
"tags": ["latents", "image"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
@ -574,24 +554,30 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
||||||
|
|
||||||
|
|
||||||
|
@title("Resize Latents")
|
||||||
|
@tags("latents", "resize")
|
||||||
class ResizeLatentsInvocation(BaseInvocation):
|
class ResizeLatentsInvocation(BaseInvocation):
|
||||||
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
||||||
|
|
||||||
type: Literal["lresize"] = "lresize"
|
type: Literal["lresize"] = "lresize"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to resize")
|
latents: LatentsField = InputField(
|
||||||
width: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The width to resize to (px)")
|
description=FieldDescriptions.latents,
|
||||||
height: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The height to resize to (px)")
|
input=Input.Connection,
|
||||||
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
|
||||||
antialias: bool = Field(
|
|
||||||
default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
|
|
||||||
)
|
)
|
||||||
|
width: int = InputField(
|
||||||
class Config(InvocationConfig):
|
ge=64,
|
||||||
schema_extra = {
|
multiple_of=8,
|
||||||
"ui": {"title": "Resize Latents", "tags": ["latents", "resize"]},
|
description=FieldDescriptions.width,
|
||||||
}
|
)
|
||||||
|
height: int = InputField(
|
||||||
|
ge=64,
|
||||||
|
multiple_of=8,
|
||||||
|
description=FieldDescriptions.width,
|
||||||
|
)
|
||||||
|
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
|
||||||
|
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
@ -616,23 +602,21 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Scale Latents")
|
||||||
|
@tags("latents", "resize")
|
||||||
class ScaleLatentsInvocation(BaseInvocation):
|
class ScaleLatentsInvocation(BaseInvocation):
|
||||||
"""Scales latents by a given factor."""
|
"""Scales latents by a given factor."""
|
||||||
|
|
||||||
type: Literal["lscale"] = "lscale"
|
type: Literal["lscale"] = "lscale"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to scale")
|
latents: LatentsField = InputField(
|
||||||
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
|
description=FieldDescriptions.latents,
|
||||||
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
input=Input.Connection,
|
||||||
antialias: bool = Field(
|
|
||||||
default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
|
|
||||||
)
|
)
|
||||||
|
scale_factor: float = InputField(gt=0, description=FieldDescriptions.scale_factor)
|
||||||
class Config(InvocationConfig):
|
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
|
||||||
schema_extra = {
|
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
||||||
"ui": {"title": "Scale Latents", "tags": ["latents", "scale"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
@ -658,22 +642,23 @@ class ScaleLatentsInvocation(BaseInvocation):
|
|||||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Image to Latents")
|
||||||
|
@tags("latents", "image", "vae")
|
||||||
class ImageToLatentsInvocation(BaseInvocation):
|
class ImageToLatentsInvocation(BaseInvocation):
|
||||||
"""Encodes an image into latents."""
|
"""Encodes an image into latents."""
|
||||||
|
|
||||||
type: Literal["i2l"] = "i2l"
|
type: Literal["i2l"] = "i2l"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(description="The image to encode")
|
image: ImageField = InputField(
|
||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
description="The image to encode",
|
||||||
tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)")
|
)
|
||||||
fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
|
vae: VaeField = InputField(
|
||||||
|
description=FieldDescriptions.vae,
|
||||||
# Schema customisation
|
input=Input.Connection,
|
||||||
class Config(InvocationConfig):
|
)
|
||||||
schema_extra = {
|
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
||||||
"ui": {"title": "Image To Latents", "tags": ["latents", "image"]},
|
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
|
||||||
}
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
|
@ -2,134 +2,83 @@
|
|||||||
|
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .baseinvocation import (
|
from invokeai.app.invocations.primitives import IntegerOutput
|
||||||
BaseInvocation,
|
|
||||||
BaseInvocationOutput,
|
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, tags, title
|
||||||
InvocationContext,
|
|
||||||
InvocationConfig,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MathInvocationConfig(BaseModel):
|
@title("Add Integers")
|
||||||
"""Helper class to provide all math invocations with additional config"""
|
@tags("math")
|
||||||
|
class AddInvocation(BaseInvocation):
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["math"],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class IntOutput(BaseInvocationOutput):
|
|
||||||
"""An integer output"""
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["int_output"] = "int_output"
|
|
||||||
a: int = Field(default=None, description="The output integer")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
class FloatOutput(BaseInvocationOutput):
|
|
||||||
"""A float output"""
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["float_output"] = "float_output"
|
|
||||||
param: float = Field(default=None, description="The output float")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
class AddInvocation(BaseInvocation, MathInvocationConfig):
|
|
||||||
"""Adds two numbers"""
|
"""Adds two numbers"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["add"] = "add"
|
type: Literal["add"] = "add"
|
||||||
a: int = Field(default=0, description="The first number")
|
|
||||||
b: int = Field(default=0, description="The second number")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||||
"ui": {"title": "Add", "tags": ["math", "add"]},
|
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||||
return IntOutput(a=self.a + self.b)
|
return IntegerOutput(a=self.a + self.b)
|
||||||
|
|
||||||
|
|
||||||
class SubtractInvocation(BaseInvocation, MathInvocationConfig):
|
@title("Subtract Integers")
|
||||||
|
@tags("math")
|
||||||
|
class SubtractInvocation(BaseInvocation):
|
||||||
"""Subtracts two numbers"""
|
"""Subtracts two numbers"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["sub"] = "sub"
|
type: Literal["sub"] = "sub"
|
||||||
a: int = Field(default=0, description="The first number")
|
|
||||||
b: int = Field(default=0, description="The second number")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||||
"ui": {"title": "Subtract", "tags": ["math", "subtract"]},
|
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||||
return IntOutput(a=self.a - self.b)
|
return IntegerOutput(a=self.a - self.b)
|
||||||
|
|
||||||
|
|
||||||
class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
|
@title("Multiply Integers")
|
||||||
|
@tags("math")
|
||||||
|
class MultiplyInvocation(BaseInvocation):
|
||||||
"""Multiplies two numbers"""
|
"""Multiplies two numbers"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["mul"] = "mul"
|
type: Literal["mul"] = "mul"
|
||||||
a: int = Field(default=0, description="The first number")
|
|
||||||
b: int = Field(default=0, description="The second number")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||||
"ui": {"title": "Multiply", "tags": ["math", "multiply"]},
|
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||||
return IntOutput(a=self.a * self.b)
|
return IntegerOutput(a=self.a * self.b)
|
||||||
|
|
||||||
|
|
||||||
class DivideInvocation(BaseInvocation, MathInvocationConfig):
|
@title("Divide Integers")
|
||||||
|
@tags("math")
|
||||||
|
class DivideInvocation(BaseInvocation):
|
||||||
"""Divides two numbers"""
|
"""Divides two numbers"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["div"] = "div"
|
type: Literal["div"] = "div"
|
||||||
a: int = Field(default=0, description="The first number")
|
|
||||||
b: int = Field(default=0, description="The second number")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||||
"ui": {"title": "Divide", "tags": ["math", "divide"]},
|
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||||
return IntOutput(a=int(self.a / self.b))
|
return IntegerOutput(a=int(self.a / self.b))
|
||||||
|
|
||||||
|
|
||||||
|
@title("Random Integer")
|
||||||
|
@tags("math")
|
||||||
class RandomIntInvocation(BaseInvocation):
|
class RandomIntInvocation(BaseInvocation):
|
||||||
"""Outputs a single random integer."""
|
"""Outputs a single random integer."""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["rand_int"] = "rand_int"
|
type: Literal["rand_int"] = "rand_int"
|
||||||
low: int = Field(default=0, description="The inclusive low value")
|
|
||||||
high: int = Field(
|
|
||||||
default=np.iinfo(np.int32).max, description="The exclusive high value"
|
|
||||||
)
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
low: int = InputField(default=0, description="The inclusive low value")
|
||||||
"ui": {"title": "Random Integer", "tags": ["math", "random", "integer"]},
|
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||||
return IntOutput(a=np.random.randint(self.low, self.high))
|
return IntegerOutput(a=np.random.randint(self.low, self.high))
|
||||||
|
@ -1,18 +1,22 @@
|
|||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from ...version import __version__
|
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
InvocationConfig,
|
InputField,
|
||||||
InvocationContext,
|
InvocationContext,
|
||||||
|
OutputField,
|
||||||
|
tags,
|
||||||
|
title,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||||
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
||||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||||
|
|
||||||
|
from ...version import __version__
|
||||||
|
|
||||||
|
|
||||||
class LoRAMetadataField(BaseModelExcludeNull):
|
class LoRAMetadataField(BaseModelExcludeNull):
|
||||||
"""LoRA metadata for an image generated in InvokeAI."""
|
"""LoRA metadata for an image generated in InvokeAI."""
|
||||||
@ -43,37 +47,37 @@ class CoreMetadata(BaseModelExcludeNull):
|
|||||||
model: MainModelField = Field(description="The main model used for inference")
|
model: MainModelField = Field(description="The main model used for inference")
|
||||||
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
|
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
|
||||||
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
||||||
vae: Union[VAEModelField, None] = Field(
|
vae: Optional[VAEModelField] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The VAE used for decoding, if the main model's default was not used",
|
description="The VAE used for decoding, if the main model's default was not used",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Latents-to-Latents
|
# Latents-to-Latents
|
||||||
strength: Union[float, None] = Field(
|
strength: Optional[float] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The strength used for latents-to-latents",
|
description="The strength used for latents-to-latents",
|
||||||
)
|
)
|
||||||
init_image: Union[str, None] = Field(default=None, description="The name of the initial image")
|
init_image: Optional[str] = Field(default=None, description="The name of the initial image")
|
||||||
|
|
||||||
# SDXL
|
# SDXL
|
||||||
positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter")
|
positive_style_prompt: Optional[str] = Field(default=None, description="The positive style prompt parameter")
|
||||||
negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter")
|
negative_style_prompt: Optional[str] = Field(default=None, description="The negative style prompt parameter")
|
||||||
|
|
||||||
# SDXL Refiner
|
# SDXL Refiner
|
||||||
refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used")
|
refiner_model: Optional[MainModelField] = Field(default=None, description="The SDXL Refiner model used")
|
||||||
refiner_cfg_scale: Union[float, None] = Field(
|
refiner_cfg_scale: Optional[float] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The classifier-free guidance scale parameter used for the refiner",
|
description="The classifier-free guidance scale parameter used for the refiner",
|
||||||
)
|
)
|
||||||
refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner")
|
refiner_steps: Optional[int] = Field(default=None, description="The number of steps used for the refiner")
|
||||||
refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner")
|
refiner_scheduler: Optional[str] = Field(default=None, description="The scheduler used for the refiner")
|
||||||
refiner_positive_aesthetic_store: Union[float, None] = Field(
|
refiner_positive_aesthetic_store: Optional[float] = Field(
|
||||||
default=None, description="The aesthetic score used for the refiner"
|
default=None, description="The aesthetic score used for the refiner"
|
||||||
)
|
)
|
||||||
refiner_negative_aesthetic_store: Union[float, None] = Field(
|
refiner_negative_aesthetic_store: Optional[float] = Field(
|
||||||
default=None, description="The aesthetic score used for the refiner"
|
default=None, description="The aesthetic score used for the refiner"
|
||||||
)
|
)
|
||||||
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
|
refiner_start: Optional[float] = Field(default=None, description="The start value used for refiner denoising")
|
||||||
|
|
||||||
|
|
||||||
class ImageMetadata(BaseModelExcludeNull):
|
class ImageMetadata(BaseModelExcludeNull):
|
||||||
@ -91,69 +95,86 @@ class MetadataAccumulatorOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
type: Literal["metadata_accumulator_output"] = "metadata_accumulator_output"
|
type: Literal["metadata_accumulator_output"] = "metadata_accumulator_output"
|
||||||
|
|
||||||
metadata: CoreMetadata = Field(description="The core metadata for the image")
|
metadata: CoreMetadata = OutputField(description="The core metadata for the image")
|
||||||
|
|
||||||
|
|
||||||
|
@title("Metadata Accumulator")
|
||||||
|
@tags("metadata")
|
||||||
class MetadataAccumulatorInvocation(BaseInvocation):
|
class MetadataAccumulatorInvocation(BaseInvocation):
|
||||||
"""Outputs a Core Metadata Object"""
|
"""Outputs a Core Metadata Object"""
|
||||||
|
|
||||||
type: Literal["metadata_accumulator"] = "metadata_accumulator"
|
type: Literal["metadata_accumulator"] = "metadata_accumulator"
|
||||||
|
|
||||||
generation_mode: str = Field(
|
generation_mode: str = InputField(
|
||||||
description="The generation mode that output this image",
|
description="The generation mode that output this image",
|
||||||
)
|
)
|
||||||
positive_prompt: str = Field(description="The positive prompt parameter")
|
positive_prompt: str = InputField(description="The positive prompt parameter")
|
||||||
negative_prompt: str = Field(description="The negative prompt parameter")
|
negative_prompt: str = InputField(description="The negative prompt parameter")
|
||||||
width: int = Field(description="The width parameter")
|
width: int = InputField(description="The width parameter")
|
||||||
height: int = Field(description="The height parameter")
|
height: int = InputField(description="The height parameter")
|
||||||
seed: int = Field(description="The seed used for noise generation")
|
seed: int = InputField(description="The seed used for noise generation")
|
||||||
rand_device: str = Field(description="The device used for random number generation")
|
rand_device: str = InputField(description="The device used for random number generation")
|
||||||
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
|
cfg_scale: float = InputField(description="The classifier-free guidance scale parameter")
|
||||||
steps: int = Field(description="The number of steps used for inference")
|
steps: int = InputField(description="The number of steps used for inference")
|
||||||
scheduler: str = Field(description="The scheduler used for inference")
|
scheduler: str = InputField(description="The scheduler used for inference")
|
||||||
clip_skip: int = Field(
|
clip_skip: int = InputField(
|
||||||
description="The number of skipped CLIP layers",
|
description="The number of skipped CLIP layers",
|
||||||
)
|
)
|
||||||
model: MainModelField = Field(description="The main model used for inference")
|
model: MainModelField = InputField(description="The main model used for inference")
|
||||||
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
|
controlnets: list[ControlField] = InputField(description="The ControlNets used for inference")
|
||||||
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
loras: list[LoRAMetadataField] = InputField(description="The LoRAs used for inference")
|
||||||
strength: Union[float, None] = Field(
|
strength: Optional[float] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description="The strength used for latents-to-latents",
|
description="The strength used for latents-to-latents",
|
||||||
)
|
)
|
||||||
init_image: Union[str, None] = Field(default=None, description="The name of the initial image")
|
init_image: Optional[str] = InputField(
|
||||||
vae: Union[VAEModelField, None] = Field(
|
default=None,
|
||||||
|
description="The name of the initial image",
|
||||||
|
)
|
||||||
|
vae: Optional[VAEModelField] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description="The VAE used for decoding, if the main model's default was not used",
|
description="The VAE used for decoding, if the main model's default was not used",
|
||||||
)
|
)
|
||||||
|
|
||||||
# SDXL
|
# SDXL
|
||||||
positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter")
|
positive_style_prompt: Optional[str] = InputField(
|
||||||
negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter")
|
default=None,
|
||||||
|
description="The positive style prompt parameter",
|
||||||
|
)
|
||||||
|
negative_style_prompt: Optional[str] = InputField(
|
||||||
|
default=None,
|
||||||
|
description="The negative style prompt parameter",
|
||||||
|
)
|
||||||
|
|
||||||
# SDXL Refiner
|
# SDXL Refiner
|
||||||
refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used")
|
refiner_model: Optional[MainModelField] = InputField(
|
||||||
refiner_cfg_scale: Union[float, None] = Field(
|
default=None,
|
||||||
|
description="The SDXL Refiner model used",
|
||||||
|
)
|
||||||
|
refiner_cfg_scale: Optional[float] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description="The classifier-free guidance scale parameter used for the refiner",
|
description="The classifier-free guidance scale parameter used for the refiner",
|
||||||
)
|
)
|
||||||
refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner")
|
refiner_steps: Optional[int] = InputField(
|
||||||
refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner")
|
default=None,
|
||||||
refiner_positive_aesthetic_score: Union[float, None] = Field(
|
description="The number of steps used for the refiner",
|
||||||
default=None, description="The aesthetic score used for the refiner"
|
|
||||||
)
|
)
|
||||||
refiner_negative_aesthetic_score: Union[float, None] = Field(
|
refiner_scheduler: Optional[str] = InputField(
|
||||||
default=None, description="The aesthetic score used for the refiner"
|
default=None,
|
||||||
|
description="The scheduler used for the refiner",
|
||||||
|
)
|
||||||
|
refiner_positive_aesthetic_store: Optional[float] = InputField(
|
||||||
|
default=None,
|
||||||
|
description="The aesthetic score used for the refiner",
|
||||||
|
)
|
||||||
|
refiner_negative_aesthetic_store: Optional[float] = InputField(
|
||||||
|
default=None,
|
||||||
|
description="The aesthetic score used for the refiner",
|
||||||
|
)
|
||||||
|
refiner_start: Optional[float] = InputField(
|
||||||
|
default=None,
|
||||||
|
description="The start value used for refiner denoising",
|
||||||
)
|
)
|
||||||
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "Metadata Accumulator",
|
|
||||||
"tags": ["image", "metadata", "generation"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
|
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
|
||||||
"""Collects and outputs a CoreMetadata object"""
|
"""Collects and outputs a CoreMetadata object"""
|
||||||
|
@ -4,7 +4,18 @@ from typing import List, Literal, Optional, Union
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
from .baseinvocation import (
|
||||||
|
BaseInvocation,
|
||||||
|
BaseInvocationOutput,
|
||||||
|
FieldDescriptions,
|
||||||
|
InputField,
|
||||||
|
Input,
|
||||||
|
InvocationContext,
|
||||||
|
OutputField,
|
||||||
|
UIType,
|
||||||
|
tags,
|
||||||
|
title,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ModelInfo(BaseModel):
|
class ModelInfo(BaseModel):
|
||||||
@ -39,13 +50,11 @@ class VaeField(BaseModel):
|
|||||||
class ModelLoaderOutput(BaseInvocationOutput):
|
class ModelLoaderOutput(BaseInvocationOutput):
|
||||||
"""Model loader output"""
|
"""Model loader output"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["model_loader_output"] = "model_loader_output"
|
type: Literal["model_loader_output"] = "model_loader_output"
|
||||||
|
|
||||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||||
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
class MainModelField(BaseModel):
|
class MainModelField(BaseModel):
|
||||||
@ -63,24 +72,17 @@ class LoRAModelField(BaseModel):
|
|||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
|
|
||||||
|
@title("Main Model Loader")
|
||||||
|
@tags("model")
|
||||||
class MainModelLoaderInvocation(BaseInvocation):
|
class MainModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads a main model, outputting its submodels."""
|
"""Loads a main model, outputting its submodels."""
|
||||||
|
|
||||||
type: Literal["main_model_loader"] = "main_model_loader"
|
type: Literal["main_model_loader"] = "main_model_loader"
|
||||||
|
|
||||||
model: MainModelField = Field(description="The model to load")
|
# Inputs
|
||||||
|
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
|
||||||
# TODO: precision?
|
# TODO: precision?
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "Model Loader",
|
|
||||||
"tags": ["model", "loader"],
|
|
||||||
"type_hints": {"model": "model"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||||
base_model = self.model.base_model
|
base_model = self.model.base_model
|
||||||
model_name = self.model.model_name
|
model_name = self.model.model_name
|
||||||
@ -155,22 +157,6 @@ class MainModelLoaderInvocation(BaseInvocation):
|
|||||||
loras=[],
|
loras=[],
|
||||||
skipped_layers=0,
|
skipped_layers=0,
|
||||||
),
|
),
|
||||||
clip2=ClipField(
|
|
||||||
tokenizer=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
submodel=SubModelType.Tokenizer2,
|
|
||||||
),
|
|
||||||
text_encoder=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
submodel=SubModelType.TextEncoder2,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
skipped_layers=0,
|
|
||||||
),
|
|
||||||
vae=VaeField(
|
vae=VaeField(
|
||||||
vae=ModelInfo(
|
vae=ModelInfo(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
@ -188,30 +174,27 @@ class LoraLoaderOutput(BaseInvocationOutput):
|
|||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["lora_loader_output"] = "lora_loader_output"
|
type: Literal["lora_loader_output"] = "lora_loader_output"
|
||||||
|
|
||||||
unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
|
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||||
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
|
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
@title("LoRA Loader")
|
||||||
|
@tags("lora", "model")
|
||||||
class LoraLoaderInvocation(BaseInvocation):
|
class LoraLoaderInvocation(BaseInvocation):
|
||||||
"""Apply selected lora to unet and text_encoder."""
|
"""Apply selected lora to unet and text_encoder."""
|
||||||
|
|
||||||
type: Literal["lora_loader"] = "lora_loader"
|
type: Literal["lora_loader"] = "lora_loader"
|
||||||
|
|
||||||
lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name")
|
# Inputs
|
||||||
weight: float = Field(default=0.75, description="With what weight to apply lora")
|
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||||
|
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||||
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
|
unet: Optional[UNetField] = InputField(
|
||||||
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
|
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
|
||||||
|
)
|
||||||
class Config(InvocationConfig):
|
clip: Optional[ClipField] = InputField(
|
||||||
schema_extra = {
|
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP"
|
||||||
"ui": {
|
)
|
||||||
"title": "Lora Loader",
|
|
||||||
"tags": ["lora", "loader"],
|
|
||||||
"type_hints": {"lora": "lora_model"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
||||||
if self.lora is None:
|
if self.lora is None:
|
||||||
@ -263,37 +246,35 @@ class LoraLoaderInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
||||||
"""Model loader output"""
|
"""SDXL LoRA Loader Output"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output"
|
type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output"
|
||||||
|
|
||||||
unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
|
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||||
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
|
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
|
||||||
clip2: Optional[ClipField] = Field(default=None, description="Tokenizer2 and text_encoder2 submodels")
|
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
@title("SDXL LoRA Loader")
|
||||||
|
@tags("sdxl", "lora", "model")
|
||||||
class SDXLLoraLoaderInvocation(BaseInvocation):
|
class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||||
"""Apply selected lora to unet and text_encoder."""
|
"""Apply selected lora to unet and text_encoder."""
|
||||||
|
|
||||||
type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader"
|
type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader"
|
||||||
|
|
||||||
lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name")
|
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||||
weight: float = Field(default=0.75, description="With what weight to apply lora")
|
weight: float = Field(default=0.75, description=FieldDescriptions.lora_weight)
|
||||||
|
unet: Optional[UNetField] = Field(
|
||||||
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
|
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNET"
|
||||||
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
|
)
|
||||||
clip2: Optional[ClipField] = Field(description="Clip2 model for applying lora")
|
clip: Optional[ClipField] = Field(
|
||||||
|
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1"
|
||||||
class Config(InvocationConfig):
|
)
|
||||||
schema_extra = {
|
clip2: Optional[ClipField] = Field(
|
||||||
"ui": {
|
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2"
|
||||||
"title": "SDXL Lora Loader",
|
)
|
||||||
"tags": ["lora", "loader"],
|
|
||||||
"type_hints": {"lora": "lora_model"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
|
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
|
||||||
if self.lora is None:
|
if self.lora is None:
|
||||||
@ -369,29 +350,23 @@ class VAEModelField(BaseModel):
|
|||||||
class VaeLoaderOutput(BaseInvocationOutput):
|
class VaeLoaderOutput(BaseInvocationOutput):
|
||||||
"""Model loader output"""
|
"""Model loader output"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["vae_loader_output"] = "vae_loader_output"
|
type: Literal["vae_loader_output"] = "vae_loader_output"
|
||||||
|
|
||||||
vae: VaeField = Field(default=None, description="Vae model")
|
# Outputs
|
||||||
# fmt: on
|
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
|
|
||||||
|
|
||||||
|
@title("VAE Loader")
|
||||||
|
@tags("vae", "model")
|
||||||
class VaeLoaderInvocation(BaseInvocation):
|
class VaeLoaderInvocation(BaseInvocation):
|
||||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||||
|
|
||||||
type: Literal["vae_loader"] = "vae_loader"
|
type: Literal["vae_loader"] = "vae_loader"
|
||||||
|
|
||||||
vae_model: VAEModelField = Field(description="The VAE to load")
|
# Inputs
|
||||||
|
vae_model: VAEModelField = InputField(
|
||||||
# Schema customisation
|
description=FieldDescriptions.vae_model, input=Input.Direct, ui_type=UIType.VaeModel, title="VAE"
|
||||||
class Config(InvocationConfig):
|
)
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "VAE Loader",
|
|
||||||
"tags": ["vae", "loader"],
|
|
||||||
"type_hints": {"vae_model": "vae_model"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
|
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
|
||||||
base_model = self.vae_model.base_model
|
base_model = self.vae_model.base_model
|
||||||
|
@ -1,19 +1,24 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
||||||
|
|
||||||
import math
|
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import Field, validator
|
|
||||||
import torch
|
import torch
|
||||||
from invokeai.app.invocations.latent import LatentsField
|
from pydantic import validator
|
||||||
|
|
||||||
|
from invokeai.app.invocations.latent import LatentsField
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
|
|
||||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
InvocationConfig,
|
FieldDescriptions,
|
||||||
|
InputField,
|
||||||
InvocationContext,
|
InvocationContext,
|
||||||
|
OutputField,
|
||||||
|
UIType,
|
||||||
|
tags,
|
||||||
|
title,
|
||||||
)
|
)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -61,14 +66,12 @@ Nodes
|
|||||||
class NoiseOutput(BaseInvocationOutput):
|
class NoiseOutput(BaseInvocationOutput):
|
||||||
"""Invocation noise output"""
|
"""Invocation noise output"""
|
||||||
|
|
||||||
# fmt: off
|
type: Literal["noise_output"] = "noise_output"
|
||||||
type: Literal["noise_output"] = "noise_output"
|
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
noise: LatentsField = Field(default=None, description="The output noise")
|
noise: LatentsField = OutputField(default=None, description=FieldDescriptions.noise)
|
||||||
width: int = Field(description="The width of the noise in pixels")
|
width: int = OutputField(description=FieldDescriptions.width)
|
||||||
height: int = Field(description="The height of the noise in pixels")
|
height: int = OutputField(description=FieldDescriptions.height)
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
|
def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
|
||||||
@ -79,44 +82,37 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Noise")
|
||||||
|
@tags("latents", "noise")
|
||||||
class NoiseInvocation(BaseInvocation):
|
class NoiseInvocation(BaseInvocation):
|
||||||
"""Generates latent noise."""
|
"""Generates latent noise."""
|
||||||
|
|
||||||
type: Literal["noise"] = "noise"
|
type: Literal["noise"] = "noise"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
seed: int = Field(
|
seed: int = InputField(
|
||||||
ge=0,
|
ge=0,
|
||||||
le=SEED_MAX,
|
le=SEED_MAX,
|
||||||
description="The seed to use",
|
description=FieldDescriptions.seed,
|
||||||
default_factory=get_random_seed,
|
default_factory=get_random_seed,
|
||||||
)
|
)
|
||||||
width: int = Field(
|
width: int = InputField(
|
||||||
default=512,
|
default=512,
|
||||||
multiple_of=8,
|
multiple_of=8,
|
||||||
gt=0,
|
gt=0,
|
||||||
description="The width of the resulting noise",
|
description=FieldDescriptions.width,
|
||||||
)
|
)
|
||||||
height: int = Field(
|
height: int = InputField(
|
||||||
default=512,
|
default=512,
|
||||||
multiple_of=8,
|
multiple_of=8,
|
||||||
gt=0,
|
gt=0,
|
||||||
description="The height of the resulting noise",
|
description=FieldDescriptions.height,
|
||||||
)
|
)
|
||||||
use_cpu: bool = Field(
|
use_cpu: bool = InputField(
|
||||||
default=True,
|
default=True,
|
||||||
description="Use CPU for noise generation (for reproducible results across platforms)",
|
description="Use CPU for noise generation (for reproducible results across platforms)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "Noise",
|
|
||||||
"tags": ["latents", "noise"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
@validator("seed", pre=True)
|
@validator("seed", pre=True)
|
||||||
def modulo_seed(cls, v):
|
def modulo_seed(cls, v):
|
||||||
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
|
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
|
||||||
|
@ -1,37 +1,43 @@
|
|||||||
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)
|
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import re
|
||||||
from contextlib import ExitStack
|
from contextlib import ExitStack
|
||||||
from typing import List, Literal, Optional, Union
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
import re
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, validator
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
|
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
|
||||||
from diffusers.image_processor import VaeImageProcessor
|
from diffusers.image_processor import VaeImageProcessor
|
||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
|
from pydantic import BaseModel, Field, validator
|
||||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
from tqdm import tqdm
|
||||||
from ...backend.model_management import ONNXModelPatcher
|
|
||||||
from ...backend.util import choose_torch_device
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
|
||||||
from .compel import ConditioningField
|
|
||||||
from .controlnet_image_processors import ControlField
|
|
||||||
from .image import ImageOutput
|
|
||||||
from .model import ModelInfo, UNetField, VaeField
|
|
||||||
|
|
||||||
from invokeai.app.invocations.metadata import CoreMetadata
|
from invokeai.app.invocations.metadata import CoreMetadata
|
||||||
from invokeai.backend import BaseModelType, ModelType, SubModelType
|
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
|
from invokeai.backend import BaseModelType, ModelType, SubModelType
|
||||||
|
|
||||||
|
from ...backend.model_management import ONNXModelPatcher
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
|
from ...backend.util import choose_torch_device
|
||||||
from tqdm import tqdm
|
from ..models.image import ImageCategory, ResourceOrigin
|
||||||
from .model import ClipField
|
from .baseinvocation import (
|
||||||
from .latent import LatentsField, LatentsOutput, build_latents_output, get_scheduler, SAMPLER_NAME_VALUES
|
BaseInvocation,
|
||||||
from .compel import CompelOutput
|
BaseInvocationOutput,
|
||||||
|
FieldDescriptions,
|
||||||
|
InputField,
|
||||||
|
Input,
|
||||||
|
InvocationContext,
|
||||||
|
OutputField,
|
||||||
|
UIComponent,
|
||||||
|
UIType,
|
||||||
|
tags,
|
||||||
|
title,
|
||||||
|
)
|
||||||
|
from .controlnet_image_processors import ControlField
|
||||||
|
from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler
|
||||||
|
from .model import ClipField, ModelInfo, UNetField, VaeField
|
||||||
|
|
||||||
ORT_TO_NP_TYPE = {
|
ORT_TO_NP_TYPE = {
|
||||||
"tensor(bool)": np.bool_,
|
"tensor(bool)": np.bool_,
|
||||||
@ -51,13 +57,15 @@ ORT_TO_NP_TYPE = {
|
|||||||
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
|
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
|
||||||
|
|
||||||
|
|
||||||
|
@title("ONNX Prompt (Raw)")
|
||||||
|
@tags("onnx", "prompt")
|
||||||
class ONNXPromptInvocation(BaseInvocation):
|
class ONNXPromptInvocation(BaseInvocation):
|
||||||
type: Literal["prompt_onnx"] = "prompt_onnx"
|
type: Literal["prompt_onnx"] = "prompt_onnx"
|
||||||
|
|
||||||
prompt: str = Field(default="", description="Prompt")
|
prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
|
||||||
clip: ClipField = Field(None, description="Clip to use")
|
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
**self.clip.tokenizer.dict(),
|
**self.clip.tokenizer.dict(),
|
||||||
)
|
)
|
||||||
@ -126,7 +134,7 @@ class ONNXPromptInvocation(BaseInvocation):
|
|||||||
# TODO: hacky but works ;D maybe rename latents somehow?
|
# TODO: hacky but works ;D maybe rename latents somehow?
|
||||||
context.services.latents.save(conditioning_name, (prompt_embeds, None))
|
context.services.latents.save(conditioning_name, (prompt_embeds, None))
|
||||||
|
|
||||||
return CompelOutput(
|
return ConditioningOutput(
|
||||||
conditioning=ConditioningField(
|
conditioning=ConditioningField(
|
||||||
conditioning_name=conditioning_name,
|
conditioning_name=conditioning_name,
|
||||||
),
|
),
|
||||||
@ -134,25 +142,48 @@ class ONNXPromptInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
# Text to image
|
# Text to image
|
||||||
|
@title("ONNX Text to Latents")
|
||||||
|
@tags("latents", "inference", "txt2img", "onnx")
|
||||||
class ONNXTextToLatentsInvocation(BaseInvocation):
|
class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||||
"""Generates latents from conditionings."""
|
"""Generates latents from conditionings."""
|
||||||
|
|
||||||
type: Literal["t2l_onnx"] = "t2l_onnx"
|
type: Literal["t2l_onnx"] = "t2l_onnx"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
# fmt: off
|
positive_conditioning: ConditioningField = InputField(
|
||||||
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
description=FieldDescriptions.positive_cond,
|
||||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
input=Input.Connection,
|
||||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
)
|
||||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
negative_conditioning: ConditioningField = InputField(
|
||||||
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
description=FieldDescriptions.negative_cond,
|
||||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
input=Input.Connection,
|
||||||
precision: PRECISION_VALUES = Field(default = "tensor(float16)", description="The precision to use when generating latents")
|
)
|
||||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
noise: LatentsField = InputField(
|
||||||
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
description=FieldDescriptions.noise,
|
||||||
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
input=Input.Connection,
|
||||||
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
)
|
||||||
# fmt: on
|
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
||||||
|
cfg_scale: Union[float, List[float]] = InputField(
|
||||||
|
default=7.5,
|
||||||
|
ge=1,
|
||||||
|
description=FieldDescriptions.cfg_scale,
|
||||||
|
ui_type=UIType.Float,
|
||||||
|
)
|
||||||
|
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||||
|
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct
|
||||||
|
)
|
||||||
|
precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision)
|
||||||
|
unet: UNetField = InputField(
|
||||||
|
description=FieldDescriptions.unet,
|
||||||
|
input=Input.Connection,
|
||||||
|
)
|
||||||
|
control: Optional[Union[ControlField, list[ControlField]]] = InputField(
|
||||||
|
default=None,
|
||||||
|
description=FieldDescriptions.control,
|
||||||
|
ui_type=UIType.Control,
|
||||||
|
)
|
||||||
|
# seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||||
|
# seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||||
|
|
||||||
@validator("cfg_scale")
|
@validator("cfg_scale")
|
||||||
def ge_one(cls, v):
|
def ge_one(cls, v):
|
||||||
@ -166,20 +197,6 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
|||||||
raise ValueError("cfg_scale must be greater than 1")
|
raise ValueError("cfg_scale must be greater than 1")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["latents"],
|
|
||||||
"type_hints": {
|
|
||||||
"model": "model",
|
|
||||||
"control": "control",
|
|
||||||
# "cfg_scale": "float",
|
|
||||||
"cfg_scale": "number",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# based on
|
# based on
|
||||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
@ -300,26 +317,28 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
# Latent to image
|
# Latent to image
|
||||||
|
@title("ONNX Latents to Image")
|
||||||
|
@tags("latents", "image", "vae", "onnx")
|
||||||
class ONNXLatentsToImageInvocation(BaseInvocation):
|
class ONNXLatentsToImageInvocation(BaseInvocation):
|
||||||
"""Generates an image from latents."""
|
"""Generates an image from latents."""
|
||||||
|
|
||||||
type: Literal["l2i_onnx"] = "l2i_onnx"
|
type: Literal["l2i_onnx"] = "l2i_onnx"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
|
latents: LatentsField = InputField(
|
||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
description=FieldDescriptions.denoised_latents,
|
||||||
metadata: Optional[CoreMetadata] = Field(
|
input=Input.Connection,
|
||||||
default=None, description="Optional core metadata to be written to the image"
|
|
||||||
)
|
)
|
||||||
# tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
|
vae: VaeField = InputField(
|
||||||
|
description=FieldDescriptions.vae,
|
||||||
# Schema customisation
|
input=Input.Connection,
|
||||||
class Config(InvocationConfig):
|
)
|
||||||
schema_extra = {
|
metadata: Optional[CoreMetadata] = InputField(
|
||||||
"ui": {
|
default=None,
|
||||||
"tags": ["latents", "image"],
|
description=FieldDescriptions.core_metadata,
|
||||||
},
|
ui_hidden=True,
|
||||||
}
|
)
|
||||||
|
# tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
@ -373,89 +392,13 @@ class ONNXModelLoaderOutput(BaseInvocationOutput):
|
|||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx"
|
type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx"
|
||||||
|
|
||||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
unet: UNetField = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||||
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||||
vae_decoder: VaeField = Field(default=None, description="Vae submodel")
|
vae_decoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Decoder")
|
||||||
vae_encoder: VaeField = Field(default=None, description="Vae submodel")
|
vae_encoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Encoder")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
class ONNXSD1ModelLoaderInvocation(BaseInvocation):
|
|
||||||
"""Loading submodels of selected model."""
|
|
||||||
|
|
||||||
type: Literal["sd1_model_loader_onnx"] = "sd1_model_loader_onnx"
|
|
||||||
|
|
||||||
model_name: str = Field(default="", description="Model to load")
|
|
||||||
# TODO: precision?
|
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"tags": ["model", "loader"], "type_hints": {"model_name": "model"}}, # TODO: rename to model_name?
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
|
|
||||||
model_name = "stable-diffusion-v1-5"
|
|
||||||
base_model = BaseModelType.StableDiffusion1
|
|
||||||
|
|
||||||
# TODO: not found exceptions
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=BaseModelType.StableDiffusion1,
|
|
||||||
model_type=ModelType.ONNX,
|
|
||||||
):
|
|
||||||
raise Exception(f"Unkown model name: {model_name}!")
|
|
||||||
|
|
||||||
return ONNXModelLoaderOutput(
|
|
||||||
unet=UNetField(
|
|
||||||
unet=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.ONNX,
|
|
||||||
submodel=SubModelType.UNet,
|
|
||||||
),
|
|
||||||
scheduler=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.ONNX,
|
|
||||||
submodel=SubModelType.Scheduler,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
),
|
|
||||||
clip=ClipField(
|
|
||||||
tokenizer=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.ONNX,
|
|
||||||
submodel=SubModelType.Tokenizer,
|
|
||||||
),
|
|
||||||
text_encoder=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.ONNX,
|
|
||||||
submodel=SubModelType.TextEncoder,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
),
|
|
||||||
vae_decoder=VaeField(
|
|
||||||
vae=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.ONNX,
|
|
||||||
submodel=SubModelType.VaeDecoder,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
vae_encoder=VaeField(
|
|
||||||
vae=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.ONNX,
|
|
||||||
submodel=SubModelType.VaeEncoder,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class OnnxModelField(BaseModel):
|
class OnnxModelField(BaseModel):
|
||||||
"""Onnx model field"""
|
"""Onnx model field"""
|
||||||
|
|
||||||
@ -464,22 +407,17 @@ class OnnxModelField(BaseModel):
|
|||||||
model_type: ModelType = Field(description="Model Type")
|
model_type: ModelType = Field(description="Model Type")
|
||||||
|
|
||||||
|
|
||||||
|
@title("ONNX Model Loader")
|
||||||
|
@tags("onnx", "model")
|
||||||
class OnnxModelLoaderInvocation(BaseInvocation):
|
class OnnxModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads a main model, outputting its submodels."""
|
"""Loads a main model, outputting its submodels."""
|
||||||
|
|
||||||
type: Literal["onnx_model_loader"] = "onnx_model_loader"
|
type: Literal["onnx_model_loader"] = "onnx_model_loader"
|
||||||
|
|
||||||
model: OnnxModelField = Field(description="The model to load")
|
# Inputs
|
||||||
|
model: OnnxModelField = InputField(
|
||||||
# Schema customisation
|
description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel
|
||||||
class Config(InvocationConfig):
|
)
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "Onnx Model Loader",
|
|
||||||
"tags": ["model", "loader"],
|
|
||||||
"type_hints": {"model": "model"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
|
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
|
||||||
base_model = self.model.base_model
|
base_model = self.model.base_model
|
||||||
|
@ -1,73 +1,64 @@
|
|||||||
import io
|
import io
|
||||||
from typing import Literal, Optional, Any
|
from typing import Literal, Optional
|
||||||
|
|
||||||
# from PIL.Image import Image
|
|
||||||
import PIL.Image
|
|
||||||
from matplotlib.ticker import MaxNLocator
|
|
||||||
from matplotlib.figure import Figure
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
import numpy as np
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import PIL.Image
|
||||||
from easing_functions import (
|
from easing_functions import (
|
||||||
LinearInOut,
|
|
||||||
QuadEaseInOut,
|
|
||||||
QuadEaseIn,
|
|
||||||
QuadEaseOut,
|
|
||||||
CubicEaseInOut,
|
|
||||||
CubicEaseIn,
|
|
||||||
CubicEaseOut,
|
|
||||||
QuarticEaseInOut,
|
|
||||||
QuarticEaseIn,
|
|
||||||
QuarticEaseOut,
|
|
||||||
QuinticEaseInOut,
|
|
||||||
QuinticEaseIn,
|
|
||||||
QuinticEaseOut,
|
|
||||||
SineEaseInOut,
|
|
||||||
SineEaseIn,
|
|
||||||
SineEaseOut,
|
|
||||||
CircularEaseIn,
|
|
||||||
CircularEaseInOut,
|
|
||||||
CircularEaseOut,
|
|
||||||
ExponentialEaseInOut,
|
|
||||||
ExponentialEaseIn,
|
|
||||||
ExponentialEaseOut,
|
|
||||||
ElasticEaseIn,
|
|
||||||
ElasticEaseInOut,
|
|
||||||
ElasticEaseOut,
|
|
||||||
BackEaseIn,
|
BackEaseIn,
|
||||||
BackEaseInOut,
|
BackEaseInOut,
|
||||||
BackEaseOut,
|
BackEaseOut,
|
||||||
BounceEaseIn,
|
BounceEaseIn,
|
||||||
BounceEaseInOut,
|
BounceEaseInOut,
|
||||||
BounceEaseOut,
|
BounceEaseOut,
|
||||||
|
CircularEaseIn,
|
||||||
|
CircularEaseInOut,
|
||||||
|
CircularEaseOut,
|
||||||
|
CubicEaseIn,
|
||||||
|
CubicEaseInOut,
|
||||||
|
CubicEaseOut,
|
||||||
|
ElasticEaseIn,
|
||||||
|
ElasticEaseInOut,
|
||||||
|
ElasticEaseOut,
|
||||||
|
ExponentialEaseIn,
|
||||||
|
ExponentialEaseInOut,
|
||||||
|
ExponentialEaseOut,
|
||||||
|
LinearInOut,
|
||||||
|
QuadEaseIn,
|
||||||
|
QuadEaseInOut,
|
||||||
|
QuadEaseOut,
|
||||||
|
QuarticEaseIn,
|
||||||
|
QuarticEaseInOut,
|
||||||
|
QuarticEaseOut,
|
||||||
|
QuinticEaseIn,
|
||||||
|
QuinticEaseInOut,
|
||||||
|
QuinticEaseOut,
|
||||||
|
SineEaseIn,
|
||||||
|
SineEaseInOut,
|
||||||
|
SineEaseOut,
|
||||||
)
|
)
|
||||||
|
from matplotlib.figure import Figure
|
||||||
|
from matplotlib.ticker import MaxNLocator
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from invokeai.app.invocations.primitives import FloatCollectionOutput
|
||||||
|
|
||||||
from .baseinvocation import (
|
|
||||||
BaseInvocation,
|
|
||||||
BaseInvocationOutput,
|
|
||||||
InvocationContext,
|
|
||||||
InvocationConfig,
|
|
||||||
)
|
|
||||||
from ...backend.util.logging import InvokeAILogger
|
from ...backend.util.logging import InvokeAILogger
|
||||||
from .collections import FloatCollectionOutput
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
|
||||||
|
|
||||||
|
|
||||||
|
@title("Float Range")
|
||||||
|
@tags("math", "range")
|
||||||
class FloatLinearRangeInvocation(BaseInvocation):
|
class FloatLinearRangeInvocation(BaseInvocation):
|
||||||
"""Creates a range"""
|
"""Creates a range"""
|
||||||
|
|
||||||
type: Literal["float_range"] = "float_range"
|
type: Literal["float_range"] = "float_range"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
start: float = Field(default=5, description="The first value of the range")
|
start: float = InputField(default=5, description="The first value of the range")
|
||||||
stop: float = Field(default=10, description="The last value of the range")
|
stop: float = InputField(default=10, description="The last value of the range")
|
||||||
steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)")
|
steps: int = InputField(default=30, description="number of values to interpolate over (including start and stop)")
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Linear Range (Float)", "tags": ["math", "float", "linear", "range"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||||
param_list = list(np.linspace(self.start, self.stop, self.steps))
|
param_list = list(np.linspace(self.start, self.stop, self.steps))
|
||||||
@ -108,37 +99,32 @@ EASING_FUNCTIONS_MAP = {
|
|||||||
"BounceInOut": BounceEaseInOut,
|
"BounceInOut": BounceEaseInOut,
|
||||||
}
|
}
|
||||||
|
|
||||||
EASING_FUNCTION_KEYS: Any = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
|
EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
|
||||||
|
|
||||||
|
|
||||||
# actually I think for now could just use CollectionOutput (which is list[Any]
|
# actually I think for now could just use CollectionOutput (which is list[Any]
|
||||||
|
@title("Step Param Easing")
|
||||||
|
@tags("step", "easing")
|
||||||
class StepParamEasingInvocation(BaseInvocation):
|
class StepParamEasingInvocation(BaseInvocation):
|
||||||
"""Experimental per-step parameter easing for denoising steps"""
|
"""Experimental per-step parameter easing for denoising steps"""
|
||||||
|
|
||||||
type: Literal["step_param_easing"] = "step_param_easing"
|
type: Literal["step_param_easing"] = "step_param_easing"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
# fmt: off
|
easing: EASING_FUNCTION_KEYS = InputField(default="Linear", description="The easing function to use")
|
||||||
easing: EASING_FUNCTION_KEYS = Field(default="Linear", description="The easing function to use")
|
num_steps: int = InputField(default=20, description="number of denoising steps")
|
||||||
num_steps: int = Field(default=20, description="number of denoising steps")
|
start_value: float = InputField(default=0.0, description="easing starting value")
|
||||||
start_value: float = Field(default=0.0, description="easing starting value")
|
end_value: float = InputField(default=1.0, description="easing ending value")
|
||||||
end_value: float = Field(default=1.0, description="easing ending value")
|
start_step_percent: float = InputField(default=0.0, description="fraction of steps at which to start easing")
|
||||||
start_step_percent: float = Field(default=0.0, description="fraction of steps at which to start easing")
|
end_step_percent: float = InputField(default=1.0, description="fraction of steps after which to end easing")
|
||||||
end_step_percent: float = Field(default=1.0, description="fraction of steps after which to end easing")
|
|
||||||
# if None, then start_value is used prior to easing start
|
# if None, then start_value is used prior to easing start
|
||||||
pre_start_value: Optional[float] = Field(default=None, description="value before easing start")
|
pre_start_value: Optional[float] = InputField(default=None, description="value before easing start")
|
||||||
# if None, then end value is used prior to easing end
|
# if None, then end value is used prior to easing end
|
||||||
post_end_value: Optional[float] = Field(default=None, description="value after easing end")
|
post_end_value: Optional[float] = InputField(default=None, description="value after easing end")
|
||||||
mirror: bool = Field(default=False, description="include mirror of easing function")
|
mirror: bool = InputField(default=False, description="include mirror of easing function")
|
||||||
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
|
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
|
||||||
# alt_mirror: bool = Field(default=False, description="alternative mirroring by dual easing")
|
# alt_mirror: bool = InputField(default=False, description="alternative mirroring by dual easing")
|
||||||
show_easing_plot: bool = Field(default=False, description="show easing plot")
|
show_easing_plot: bool = InputField(default=False, description="show easing plot")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"title": "Param Easing By Step", "tags": ["param", "step", "easing"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||||
log_diagnostics = False
|
log_diagnostics = False
|
||||||
|
@ -1,83 +0,0 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
|
||||||
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from invokeai.app.invocations.prompt import PromptOutput
|
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
|
||||||
from .math import FloatOutput, IntOutput
|
|
||||||
|
|
||||||
# Pass-through parameter nodes - used by subgraphs
|
|
||||||
|
|
||||||
|
|
||||||
class ParamIntInvocation(BaseInvocation):
|
|
||||||
"""An integer parameter"""
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["param_int"] = "param_int"
|
|
||||||
a: int = Field(default=0, description="The integer value")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"tags": ["param", "integer"], "title": "Integer Parameter"},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
|
||||||
return IntOutput(a=self.a)
|
|
||||||
|
|
||||||
|
|
||||||
class ParamFloatInvocation(BaseInvocation):
|
|
||||||
"""A float parameter"""
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["param_float"] = "param_float"
|
|
||||||
param: float = Field(default=0.0, description="The float value")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"tags": ["param", "float"], "title": "Float Parameter"},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
|
||||||
return FloatOutput(param=self.param)
|
|
||||||
|
|
||||||
|
|
||||||
class StringOutput(BaseInvocationOutput):
|
|
||||||
"""A string output"""
|
|
||||||
|
|
||||||
type: Literal["string_output"] = "string_output"
|
|
||||||
text: str = Field(default=None, description="The output string")
|
|
||||||
|
|
||||||
|
|
||||||
class ParamStringInvocation(BaseInvocation):
|
|
||||||
"""A string parameter"""
|
|
||||||
|
|
||||||
type: Literal["param_string"] = "param_string"
|
|
||||||
text: str = Field(default="", description="The string value")
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"tags": ["param", "string"], "title": "String Parameter"},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
|
||||||
return StringOutput(text=self.text)
|
|
||||||
|
|
||||||
|
|
||||||
class ParamPromptInvocation(BaseInvocation):
|
|
||||||
"""A prompt input parameter"""
|
|
||||||
|
|
||||||
type: Literal["param_prompt"] = "param_prompt"
|
|
||||||
prompt: str = Field(default="", description="The prompt value")
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"tags": ["param", "prompt"], "title": "Prompt"},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> PromptOutput:
|
|
||||||
return PromptOutput(prompt=self.prompt)
|
|
494
invokeai/app/invocations/primitives.py
Normal file
494
invokeai/app/invocations/primitives.py
Normal file
@ -0,0 +1,494 @@
|
|||||||
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
from typing import Literal, Optional, Tuple, Union
|
||||||
|
from anyio import Condition
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .baseinvocation import (
|
||||||
|
BaseInvocation,
|
||||||
|
BaseInvocationOutput,
|
||||||
|
FieldDescriptions,
|
||||||
|
Input,
|
||||||
|
InputField,
|
||||||
|
InvocationContext,
|
||||||
|
OutputField,
|
||||||
|
UIComponent,
|
||||||
|
UIType,
|
||||||
|
tags,
|
||||||
|
title,
|
||||||
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Primitives: Boolean, Integer, Float, String, Image, Latents, Conditioning, Color
|
||||||
|
- primitive nodes
|
||||||
|
- primitive outputs
|
||||||
|
- primitive collection outputs
|
||||||
|
"""
|
||||||
|
|
||||||
|
# region Boolean
|
||||||
|
|
||||||
|
|
||||||
|
class BooleanOutput(BaseInvocationOutput):
|
||||||
|
"""Base class for nodes that output a single boolean"""
|
||||||
|
|
||||||
|
type: Literal["boolean_output"] = "boolean_output"
|
||||||
|
a: bool = OutputField(description="The output boolean")
|
||||||
|
|
||||||
|
|
||||||
|
class BooleanCollectionOutput(BaseInvocationOutput):
|
||||||
|
"""Base class for nodes that output a collection of booleans"""
|
||||||
|
|
||||||
|
type: Literal["boolean_collection_output"] = "boolean_collection_output"
|
||||||
|
|
||||||
|
# Outputs
|
||||||
|
collection: list[bool] = OutputField(
|
||||||
|
default_factory=list, description="The output boolean collection", ui_type=UIType.BooleanCollection
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Boolean Primitive")
|
||||||
|
@tags("primitives", "boolean")
|
||||||
|
class BooleanInvocation(BaseInvocation):
|
||||||
|
"""A boolean primitive value"""
|
||||||
|
|
||||||
|
type: Literal["boolean"] = "boolean"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
a: bool = InputField(default=False, description="The boolean value")
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> BooleanOutput:
|
||||||
|
return BooleanOutput(a=self.a)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Boolean Primitive Collection")
|
||||||
|
@tags("primitives", "boolean", "collection")
|
||||||
|
class BooleanCollectionInvocation(BaseInvocation):
|
||||||
|
"""A collection of boolean primitive values"""
|
||||||
|
|
||||||
|
type: Literal["boolean_collection"] = "boolean_collection"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
collection: list[bool] = InputField(
|
||||||
|
default=False, description="The collection of boolean values", ui_type=UIType.BooleanCollection
|
||||||
|
)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
|
||||||
|
return BooleanCollectionOutput(collection=self.collection)
|
||||||
|
|
||||||
|
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
# region Integer
|
||||||
|
|
||||||
|
|
||||||
|
class IntegerOutput(BaseInvocationOutput):
|
||||||
|
"""Base class for nodes that output a single integer"""
|
||||||
|
|
||||||
|
type: Literal["integer_output"] = "integer_output"
|
||||||
|
a: int = OutputField(description="The output integer")
|
||||||
|
|
||||||
|
|
||||||
|
class IntegerCollectionOutput(BaseInvocationOutput):
|
||||||
|
"""Base class for nodes that output a collection of integers"""
|
||||||
|
|
||||||
|
type: Literal["integer_collection_output"] = "integer_collection_output"
|
||||||
|
|
||||||
|
# Outputs
|
||||||
|
collection: list[int] = OutputField(
|
||||||
|
default_factory=list, description="The int collection", ui_type=UIType.IntegerCollection
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Integer Primitive")
|
||||||
|
@tags("primitives", "integer")
|
||||||
|
class IntegerInvocation(BaseInvocation):
|
||||||
|
"""An integer primitive value"""
|
||||||
|
|
||||||
|
type: Literal["integer"] = "integer"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
a: int = InputField(default=0, description="The integer value")
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||||
|
return IntegerOutput(a=self.a)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Integer Primitive Collection")
|
||||||
|
@tags("primitives", "integer", "collection")
|
||||||
|
class IntegerCollectionInvocation(BaseInvocation):
|
||||||
|
"""A collection of integer primitive values"""
|
||||||
|
|
||||||
|
type: Literal["integer_collection"] = "integer_collection"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
collection: list[int] = InputField(
|
||||||
|
default=0, description="The collection of integer values", ui_type=UIType.IntegerCollection
|
||||||
|
)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
|
||||||
|
return IntegerCollectionOutput(collection=self.collection)
|
||||||
|
|
||||||
|
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
# region Float
|
||||||
|
|
||||||
|
|
||||||
|
class FloatOutput(BaseInvocationOutput):
|
||||||
|
"""Base class for nodes that output a single float"""
|
||||||
|
|
||||||
|
type: Literal["float_output"] = "float_output"
|
||||||
|
a: float = OutputField(description="The output float")
|
||||||
|
|
||||||
|
|
||||||
|
class FloatCollectionOutput(BaseInvocationOutput):
|
||||||
|
"""Base class for nodes that output a collection of floats"""
|
||||||
|
|
||||||
|
type: Literal["float_collection_output"] = "float_collection_output"
|
||||||
|
|
||||||
|
# Outputs
|
||||||
|
collection: list[float] = OutputField(
|
||||||
|
default_factory=list, description="The float collection", ui_type=UIType.FloatCollection
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Float Primitive")
|
||||||
|
@tags("primitives", "float")
|
||||||
|
class FloatInvocation(BaseInvocation):
|
||||||
|
"""A float primitive value"""
|
||||||
|
|
||||||
|
type: Literal["float"] = "float"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
param: float = InputField(default=0.0, description="The float value")
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||||
|
return FloatOutput(a=self.param)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Float Primitive Collection")
|
||||||
|
@tags("primitives", "float", "collection")
|
||||||
|
class FloatCollectionInvocation(BaseInvocation):
|
||||||
|
"""A collection of float primitive values"""
|
||||||
|
|
||||||
|
type: Literal["float_collection"] = "float_collection"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
collection: list[float] = InputField(
|
||||||
|
default=0, description="The collection of float values", ui_type=UIType.FloatCollection
|
||||||
|
)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||||
|
return FloatCollectionOutput(collection=self.collection)
|
||||||
|
|
||||||
|
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
# region String
|
||||||
|
|
||||||
|
|
||||||
|
class StringOutput(BaseInvocationOutput):
|
||||||
|
"""Base class for nodes that output a single string"""
|
||||||
|
|
||||||
|
type: Literal["string_output"] = "string_output"
|
||||||
|
text: str = OutputField(description="The output string")
|
||||||
|
|
||||||
|
|
||||||
|
class StringCollectionOutput(BaseInvocationOutput):
|
||||||
|
"""Base class for nodes that output a collection of strings"""
|
||||||
|
|
||||||
|
type: Literal["string_collection_output"] = "string_collection_output"
|
||||||
|
|
||||||
|
# Outputs
|
||||||
|
collection: list[str] = OutputField(
|
||||||
|
default_factory=list, description="The output strings", ui_type=UIType.StringCollection
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("String Primitive")
|
||||||
|
@tags("primitives", "string")
|
||||||
|
class StringInvocation(BaseInvocation):
|
||||||
|
"""A string primitive value"""
|
||||||
|
|
||||||
|
type: Literal["string"] = "string"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
text: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||||
|
return StringOutput(text=self.text)
|
||||||
|
|
||||||
|
|
||||||
|
@title("String Primitive Collection")
|
||||||
|
@tags("primitives", "string", "collection")
|
||||||
|
class StringCollectionInvocation(BaseInvocation):
|
||||||
|
"""A collection of string primitive values"""
|
||||||
|
|
||||||
|
type: Literal["string_collection"] = "string_collection"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
collection: list[str] = InputField(
|
||||||
|
default=0, description="The collection of string values", ui_type=UIType.StringCollection
|
||||||
|
)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
||||||
|
return StringCollectionOutput(collection=self.collection)
|
||||||
|
|
||||||
|
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
# region Image
|
||||||
|
|
||||||
|
|
||||||
|
class ImageField(BaseModel):
|
||||||
|
"""An image primitive field"""
|
||||||
|
|
||||||
|
image_name: str = Field(description="The name of the image")
|
||||||
|
|
||||||
|
|
||||||
|
class ImageOutput(BaseInvocationOutput):
|
||||||
|
"""Base class for nodes that output a single image"""
|
||||||
|
|
||||||
|
type: Literal["image_output"] = "image_output"
|
||||||
|
image: ImageField = OutputField(description="The output image")
|
||||||
|
width: int = OutputField(description="The width of the image in pixels")
|
||||||
|
height: int = OutputField(description="The height of the image in pixels")
|
||||||
|
|
||||||
|
|
||||||
|
class ImageCollectionOutput(BaseInvocationOutput):
|
||||||
|
"""Base class for nodes that output a collection of images"""
|
||||||
|
|
||||||
|
type: Literal["image_collection_output"] = "image_collection_output"
|
||||||
|
|
||||||
|
# Outputs
|
||||||
|
collection: list[ImageField] = OutputField(
|
||||||
|
default_factory=list, description="The output images", ui_type=UIType.ImageCollection
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Image Primitive")
|
||||||
|
@tags("primitives", "image")
|
||||||
|
class ImageInvocation(BaseInvocation):
|
||||||
|
"""An image primitive value"""
|
||||||
|
|
||||||
|
# Metadata
|
||||||
|
type: Literal["image"] = "image"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
image: ImageField = InputField(description="The image to load")
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
|
return ImageOutput(
|
||||||
|
image=ImageField(image_name=self.image.image_name),
|
||||||
|
width=image.width,
|
||||||
|
height=image.height,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Image Primitive Collection")
|
||||||
|
@tags("primitives", "image", "collection")
|
||||||
|
class ImageCollectionInvocation(BaseInvocation):
|
||||||
|
"""A collection of image primitive values"""
|
||||||
|
|
||||||
|
type: Literal["image_collection"] = "image_collection"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
collection: list[ImageField] = InputField(
|
||||||
|
default=0, description="The collection of image values", ui_type=UIType.ImageCollection
|
||||||
|
)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
|
||||||
|
return ImageCollectionOutput(collection=self.collection)
|
||||||
|
|
||||||
|
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
# region Latents
|
||||||
|
|
||||||
|
|
||||||
|
class LatentsField(BaseModel):
|
||||||
|
"""A latents tensor primitive field"""
|
||||||
|
|
||||||
|
latents_name: str = Field(description="The name of the latents")
|
||||||
|
seed: Optional[int] = Field(default=None, description="Seed used to generate this latents")
|
||||||
|
|
||||||
|
|
||||||
|
class LatentsOutput(BaseInvocationOutput):
|
||||||
|
"""Base class for nodes that output a single latents tensor"""
|
||||||
|
|
||||||
|
type: Literal["latents_output"] = "latents_output"
|
||||||
|
|
||||||
|
latents: LatentsField = OutputField(
|
||||||
|
description=FieldDescriptions.latents,
|
||||||
|
)
|
||||||
|
width: int = OutputField(description=FieldDescriptions.width)
|
||||||
|
height: int = OutputField(description=FieldDescriptions.height)
|
||||||
|
|
||||||
|
|
||||||
|
class LatentsCollectionOutput(BaseInvocationOutput):
|
||||||
|
"""Base class for nodes that output a collection of latents tensors"""
|
||||||
|
|
||||||
|
type: Literal["latents_collection_output"] = "latents_collection_output"
|
||||||
|
|
||||||
|
collection: list[LatentsField] = OutputField(
|
||||||
|
default_factory=list,
|
||||||
|
description=FieldDescriptions.latents,
|
||||||
|
ui_type=UIType.LatentsCollection,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Latents Primitive")
|
||||||
|
@tags("primitives", "latents")
|
||||||
|
class LatentsInvocation(BaseInvocation):
|
||||||
|
"""A latents tensor primitive value"""
|
||||||
|
|
||||||
|
type: Literal["latents"] = "latents"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
|
return build_latents_output(self.latents.latents_name, latents)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Latents Primitive Collection")
|
||||||
|
@tags("primitives", "latents", "collection")
|
||||||
|
class LatentsCollectionInvocation(BaseInvocation):
|
||||||
|
"""A collection of latents tensor primitive values"""
|
||||||
|
|
||||||
|
type: Literal["latents_collection"] = "latents_collection"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
collection: list[LatentsField] = InputField(
|
||||||
|
default=0, description="The collection of latents tensors", ui_type=UIType.LatentsCollection
|
||||||
|
)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> LatentsCollectionOutput:
|
||||||
|
return LatentsCollectionOutput(collection=self.collection)
|
||||||
|
|
||||||
|
|
||||||
|
def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None):
|
||||||
|
return LatentsOutput(
|
||||||
|
latents=LatentsField(latents_name=latents_name, seed=seed),
|
||||||
|
width=latents.size()[3] * 8,
|
||||||
|
height=latents.size()[2] * 8,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
# region Color
|
||||||
|
|
||||||
|
|
||||||
|
class ColorField(BaseModel):
|
||||||
|
"""A color primitive field"""
|
||||||
|
|
||||||
|
r: int = Field(ge=0, le=255, description="The red component")
|
||||||
|
g: int = Field(ge=0, le=255, description="The green component")
|
||||||
|
b: int = Field(ge=0, le=255, description="The blue component")
|
||||||
|
a: int = Field(ge=0, le=255, description="The alpha component")
|
||||||
|
|
||||||
|
def tuple(self) -> Tuple[int, int, int, int]:
|
||||||
|
return (self.r, self.g, self.b, self.a)
|
||||||
|
|
||||||
|
|
||||||
|
class ColorOutput(BaseInvocationOutput):
|
||||||
|
"""Base class for nodes that output a single color"""
|
||||||
|
|
||||||
|
type: Literal["color_output"] = "color_output"
|
||||||
|
color: ColorField = OutputField(description="The output color")
|
||||||
|
|
||||||
|
|
||||||
|
class ColorCollectionOutput(BaseInvocationOutput):
|
||||||
|
"""Base class for nodes that output a collection of colors"""
|
||||||
|
|
||||||
|
type: Literal["color_collection_output"] = "color_collection_output"
|
||||||
|
|
||||||
|
# Outputs
|
||||||
|
collection: list[ColorField] = OutputField(
|
||||||
|
default_factory=list, description="The output colors", ui_type=UIType.ColorCollection
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Color Primitive")
|
||||||
|
@tags("primitives", "color")
|
||||||
|
class ColorInvocation(BaseInvocation):
|
||||||
|
"""A color primitive value"""
|
||||||
|
|
||||||
|
type: Literal["color"] = "color"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color value")
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ColorOutput:
|
||||||
|
return ColorOutput(color=self.color)
|
||||||
|
|
||||||
|
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
# region Conditioning
|
||||||
|
|
||||||
|
|
||||||
|
class ConditioningField(BaseModel):
|
||||||
|
"""A conditioning tensor primitive value"""
|
||||||
|
|
||||||
|
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||||
|
|
||||||
|
|
||||||
|
class ConditioningOutput(BaseInvocationOutput):
|
||||||
|
"""Base class for nodes that output a single conditioning tensor"""
|
||||||
|
|
||||||
|
type: Literal["conditioning_output"] = "conditioning_output"
|
||||||
|
|
||||||
|
conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
|
||||||
|
|
||||||
|
|
||||||
|
class ConditioningCollectionOutput(BaseInvocationOutput):
|
||||||
|
"""Base class for nodes that output a collection of conditioning tensors"""
|
||||||
|
|
||||||
|
type: Literal["conditioning_collection_output"] = "conditioning_collection_output"
|
||||||
|
|
||||||
|
# Outputs
|
||||||
|
collection: list[ConditioningField] = OutputField(
|
||||||
|
default_factory=list,
|
||||||
|
description="The output conditioning tensors",
|
||||||
|
ui_type=UIType.ConditioningCollection,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Conditioning Primitive")
|
||||||
|
@tags("primitives", "conditioning")
|
||||||
|
class ConditioningInvocation(BaseInvocation):
|
||||||
|
"""A conditioning tensor primitive value"""
|
||||||
|
|
||||||
|
type: Literal["conditioning"] = "conditioning"
|
||||||
|
|
||||||
|
conditioning: ConditioningField = InputField(description=FieldDescriptions.cond, input=Input.Connection)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
|
return ConditioningOutput(conditioning=self.conditioning)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Conditioning Primitive Collection")
|
||||||
|
@tags("primitives", "conditioning", "collection")
|
||||||
|
class ConditioningCollectionInvocation(BaseInvocation):
|
||||||
|
"""A collection of conditioning tensor primitive values"""
|
||||||
|
|
||||||
|
type: Literal["conditioning_collection"] = "conditioning_collection"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
collection: list[ConditioningField] = InputField(
|
||||||
|
default=0, description="The collection of conditioning tensors", ui_type=UIType.ConditioningCollection
|
||||||
|
)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput:
|
||||||
|
return ConditioningCollectionOutput(collection=self.collection)
|
||||||
|
|
||||||
|
|
||||||
|
# endregion
|
@ -1,59 +1,28 @@
|
|||||||
from os.path import exists
|
from os.path import exists
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import Field, validator
|
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
|
||||||
|
from pydantic import validator
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
from invokeai.app.invocations.primitives import StringCollectionOutput
|
||||||
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
|
|
||||||
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, UIType, tags, title
|
||||||
|
|
||||||
class PromptOutput(BaseInvocationOutput):
|
|
||||||
"""Base class for invocations that output a prompt"""
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["prompt"] = "prompt"
|
|
||||||
|
|
||||||
prompt: str = Field(default=None, description="The output prompt")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
schema_extra = {
|
|
||||||
"required": [
|
|
||||||
"type",
|
|
||||||
"prompt",
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class PromptCollectionOutput(BaseInvocationOutput):
|
|
||||||
"""Base class for invocations that output a collection of prompts"""
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["prompt_collection_output"] = "prompt_collection_output"
|
|
||||||
|
|
||||||
prompt_collection: list[str] = Field(description="The output prompt collection")
|
|
||||||
count: int = Field(description="The size of the prompt collection")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
schema_extra = {"required": ["type", "prompt_collection", "count"]}
|
|
||||||
|
|
||||||
|
|
||||||
|
@title("Dynamic Prompt")
|
||||||
|
@tags("prompt", "collection")
|
||||||
class DynamicPromptInvocation(BaseInvocation):
|
class DynamicPromptInvocation(BaseInvocation):
|
||||||
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
|
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
|
||||||
|
|
||||||
type: Literal["dynamic_prompt"] = "dynamic_prompt"
|
type: Literal["dynamic_prompt"] = "dynamic_prompt"
|
||||||
prompt: str = Field(description="The prompt to parse with dynamicprompts")
|
|
||||||
max_prompts: int = Field(default=1, description="The number of prompts to generate")
|
|
||||||
combinatorial: bool = Field(default=False, description="Whether to use the combinatorial generator")
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
prompt: str = InputField(description="The prompt to parse with dynamicprompts", ui_component=UIComponent.Textarea)
|
||||||
"ui": {"title": "Dynamic Prompt", "tags": ["prompt", "dynamic"]},
|
max_prompts: int = InputField(default=1, description="The number of prompts to generate")
|
||||||
}
|
combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
|
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
||||||
if self.combinatorial:
|
if self.combinatorial:
|
||||||
generator = CombinatorialPromptGenerator()
|
generator = CombinatorialPromptGenerator()
|
||||||
prompts = generator.generate(self.prompt, max_prompts=self.max_prompts)
|
prompts = generator.generate(self.prompt, max_prompts=self.max_prompts)
|
||||||
@ -61,27 +30,26 @@ class DynamicPromptInvocation(BaseInvocation):
|
|||||||
generator = RandomPromptGenerator()
|
generator = RandomPromptGenerator()
|
||||||
prompts = generator.generate(self.prompt, num_images=self.max_prompts)
|
prompts = generator.generate(self.prompt, num_images=self.max_prompts)
|
||||||
|
|
||||||
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
|
return StringCollectionOutput(collection=prompts)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Prompts from File")
|
||||||
|
@tags("prompt", "file")
|
||||||
class PromptsFromFileInvocation(BaseInvocation):
|
class PromptsFromFileInvocation(BaseInvocation):
|
||||||
"""Loads prompts from a text file"""
|
"""Loads prompts from a text file"""
|
||||||
|
|
||||||
# fmt: off
|
type: Literal["prompt_from_file"] = "prompt_from_file"
|
||||||
type: Literal['prompt_from_file'] = 'prompt_from_file'
|
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
file_path: str = Field(description="Path to prompt text file")
|
file_path: str = InputField(description="Path to prompt text file", ui_type=UIType.FilePath)
|
||||||
pre_prompt: Optional[str] = Field(description="String to prepend to each prompt")
|
pre_prompt: Optional[str] = InputField(
|
||||||
post_prompt: Optional[str] = Field(description="String to append to each prompt")
|
default=None, description="String to prepend to each prompt", ui_component=UIComponent.Textarea
|
||||||
start_line: int = Field(default=1, ge=1, description="Line in the file to start start from")
|
)
|
||||||
max_prompts: int = Field(default=1, ge=0, description="Max lines to read from file (0=all)")
|
post_prompt: Optional[str] = InputField(
|
||||||
# fmt: on
|
default=None, description="String to append to each prompt", ui_component=UIComponent.Textarea
|
||||||
|
)
|
||||||
class Config(InvocationConfig):
|
start_line: int = InputField(default=1, ge=1, description="Line in the file to start start from")
|
||||||
schema_extra = {
|
max_prompts: int = InputField(default=1, ge=0, description="Max lines to read from file (0=all)")
|
||||||
"ui": {"title": "Prompts From File", "tags": ["prompt", "file"]},
|
|
||||||
}
|
|
||||||
|
|
||||||
@validator("file_path")
|
@validator("file_path")
|
||||||
def file_path_exists(cls, v):
|
def file_path_exists(cls, v):
|
||||||
@ -89,7 +57,14 @@ class PromptsFromFileInvocation(BaseInvocation):
|
|||||||
raise ValueError(FileNotFoundError)
|
raise ValueError(FileNotFoundError)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
def promptsFromFile(self, file_path: str, pre_prompt: str, post_prompt: str, start_line: int, max_prompts: int):
|
def promptsFromFile(
|
||||||
|
self,
|
||||||
|
file_path: str,
|
||||||
|
pre_prompt: Union[str, None],
|
||||||
|
post_prompt: Union[str, None],
|
||||||
|
start_line: int,
|
||||||
|
max_prompts: int,
|
||||||
|
):
|
||||||
prompts = []
|
prompts = []
|
||||||
start_line -= 1
|
start_line -= 1
|
||||||
end_line = start_line + max_prompts
|
end_line = start_line + max_prompts
|
||||||
@ -103,8 +78,8 @@ class PromptsFromFileInvocation(BaseInvocation):
|
|||||||
break
|
break
|
||||||
return prompts
|
return prompts
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
|
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
||||||
prompts = self.promptsFromFile(
|
prompts = self.promptsFromFile(
|
||||||
self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts
|
self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts
|
||||||
)
|
)
|
||||||
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
|
return StringCollectionOutput(collection=prompts)
|
||||||
|
@ -1,55 +1,55 @@
|
|||||||
import torch
|
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from ...backend.model_management import ModelType, SubModelType
|
from ...backend.model_management import ModelType, SubModelType
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
from .baseinvocation import (
|
||||||
from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo
|
BaseInvocation,
|
||||||
|
BaseInvocationOutput,
|
||||||
|
FieldDescriptions,
|
||||||
|
Input,
|
||||||
|
InputField,
|
||||||
|
InvocationContext,
|
||||||
|
OutputField,
|
||||||
|
UIType,
|
||||||
|
tags,
|
||||||
|
title,
|
||||||
|
)
|
||||||
|
from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField
|
||||||
|
|
||||||
|
|
||||||
class SDXLModelLoaderOutput(BaseInvocationOutput):
|
class SDXLModelLoaderOutput(BaseInvocationOutput):
|
||||||
"""SDXL base model loader output"""
|
"""SDXL base model loader output"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output"
|
type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output"
|
||||||
|
|
||||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||||
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
|
||||||
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
||||||
"""SDXL refiner model loader output"""
|
"""SDXL refiner model loader output"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
|
type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
|
||||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
|
||||||
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||||
# fmt: on
|
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
|
@title("SDXL Main Model Loader")
|
||||||
|
@tags("model", "sdxl")
|
||||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads an sdxl base model, outputting its submodels."""
|
"""Loads an sdxl base model, outputting its submodels."""
|
||||||
|
|
||||||
type: Literal["sdxl_model_loader"] = "sdxl_model_loader"
|
type: Literal["sdxl_model_loader"] = "sdxl_model_loader"
|
||||||
|
|
||||||
model: MainModelField = Field(description="The model to load")
|
# Inputs
|
||||||
|
model: MainModelField = InputField(
|
||||||
|
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
|
||||||
|
)
|
||||||
# TODO: precision?
|
# TODO: precision?
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "SDXL Model Loader",
|
|
||||||
"tags": ["model", "loader", "sdxl"],
|
|
||||||
"type_hints": {"model": "model"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
|
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
|
||||||
base_model = self.model.base_model
|
base_model = self.model.base_model
|
||||||
model_name = self.model.model_name
|
model_name = self.model.model_name
|
||||||
@ -122,24 +122,21 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("SDXL Refiner Model Loader")
|
||||||
|
@tags("model", "sdxl", "refiner")
|
||||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||||
|
|
||||||
type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
|
type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
|
||||||
|
|
||||||
model: MainModelField = Field(description="The model to load")
|
# Inputs
|
||||||
|
model: MainModelField = InputField(
|
||||||
|
description=FieldDescriptions.sdxl_refiner_model,
|
||||||
|
input=Input.Direct,
|
||||||
|
ui_type=UIType.SDXLRefinerModel,
|
||||||
|
)
|
||||||
# TODO: precision?
|
# TODO: precision?
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "SDXL Refiner Model Loader",
|
|
||||||
"tags": ["model", "loader", "sdxl_refiner"],
|
|
||||||
"type_hints": {"model": "refiner_model"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
|
def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
|
||||||
base_model = self.model.base_model
|
base_model = self.model.base_model
|
||||||
model_name = self.model.model_name
|
model_name = self.model.model_name
|
||||||
|
@ -6,13 +6,12 @@ import cv2 as cv
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pydantic import Field
|
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
|
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, title, tags
|
||||||
from .image import ImageOutput
|
|
||||||
|
|
||||||
# TODO: Populate this from disk?
|
# TODO: Populate this from disk?
|
||||||
# TODO: Use model manager to load?
|
# TODO: Use model manager to load?
|
||||||
@ -24,17 +23,16 @@ ESRGAN_MODELS = Literal[
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@title("Upscale (RealESRGAN)")
|
||||||
|
@tags("esrgan", "upscale")
|
||||||
class ESRGANInvocation(BaseInvocation):
|
class ESRGANInvocation(BaseInvocation):
|
||||||
"""Upscales an image using RealESRGAN."""
|
"""Upscales an image using RealESRGAN."""
|
||||||
|
|
||||||
type: Literal["esrgan"] = "esrgan"
|
type: Literal["esrgan"] = "esrgan"
|
||||||
image: Union[ImageField, None] = Field(default=None, description="The input image")
|
|
||||||
model_name: ESRGAN_MODELS = Field(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
# Inputs
|
||||||
schema_extra = {
|
image: ImageField = InputField(description="The input image")
|
||||||
"ui": {"title": "Upscale (RealESRGAN)", "tags": ["image", "upscale", "realesrgan"]},
|
model_name: ESRGAN_MODELS = InputField(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
@ -1,31 +1,8 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional, Tuple, Literal
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.app.util.metaenum import MetaEnum
|
from invokeai.app.util.metaenum import MetaEnum
|
||||||
from ..invocations.baseinvocation import (
|
|
||||||
BaseInvocationOutput,
|
|
||||||
InvocationConfig,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ImageField(BaseModel):
|
|
||||||
"""An image field used for passing image objects between invocations"""
|
|
||||||
|
|
||||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
schema_extra = {"required": ["image_name"]}
|
|
||||||
|
|
||||||
|
|
||||||
class ColorField(BaseModel):
|
|
||||||
r: int = Field(ge=0, le=255, description="The red component")
|
|
||||||
g: int = Field(ge=0, le=255, description="The green component")
|
|
||||||
b: int = Field(ge=0, le=255, description="The blue component")
|
|
||||||
a: int = Field(ge=0, le=255, description="The alpha component")
|
|
||||||
|
|
||||||
def tuple(self) -> Tuple[int, int, int, int]:
|
|
||||||
return (self.r, self.g, self.b, self.a)
|
|
||||||
|
|
||||||
|
|
||||||
class ProgressImage(BaseModel):
|
class ProgressImage(BaseModel):
|
||||||
@ -36,50 +13,6 @@ class ProgressImage(BaseModel):
|
|||||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||||
|
|
||||||
|
|
||||||
class PILInvocationConfig(BaseModel):
|
|
||||||
"""Helper class to provide all PIL invocations with additional config"""
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["PIL", "image"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ImageOutput(BaseInvocationOutput):
|
|
||||||
"""Base class for invocations that output an image"""
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["image_output"] = "image_output"
|
|
||||||
image: ImageField = Field(default=None, description="The output image")
|
|
||||||
width: int = Field(description="The width of the image in pixels")
|
|
||||||
height: int = Field(description="The height of the image in pixels")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
schema_extra = {"required": ["type", "image", "width", "height"]}
|
|
||||||
|
|
||||||
|
|
||||||
class MaskOutput(BaseInvocationOutput):
|
|
||||||
"""Base class for invocations that output a mask"""
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["mask"] = "mask"
|
|
||||||
mask: ImageField = Field(default=None, description="The output mask")
|
|
||||||
width: int = Field(description="The width of the mask in pixels")
|
|
||||||
height: int = Field(description="The height of the mask in pixels")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
schema_extra = {
|
|
||||||
"required": [
|
|
||||||
"type",
|
|
||||||
"mask",
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ResourceOrigin(str, Enum, metaclass=MetaEnum):
|
class ResourceOrigin(str, Enum, metaclass=MetaEnum):
|
||||||
"""The origin of a resource (eg image).
|
"""The origin of a resource (eg image).
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@ from ..invocations.latent import LatentsToImageInvocation, DenoiseLatentsInvocat
|
|||||||
from ..invocations.image import ImageNSFWBlurInvocation
|
from ..invocations.image import ImageNSFWBlurInvocation
|
||||||
from ..invocations.noise import NoiseInvocation
|
from ..invocations.noise import NoiseInvocation
|
||||||
from ..invocations.compel import CompelInvocation
|
from ..invocations.compel import CompelInvocation
|
||||||
from ..invocations.params import ParamIntInvocation
|
from ..invocations.primitives import IntegerInvocation
|
||||||
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
|
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
|
||||||
from .item_storage import ItemStorageABC
|
from .item_storage import ItemStorageABC
|
||||||
|
|
||||||
@ -17,9 +17,9 @@ def create_text_to_image() -> LibraryGraph:
|
|||||||
description="Converts text to an image",
|
description="Converts text to an image",
|
||||||
graph=Graph(
|
graph=Graph(
|
||||||
nodes={
|
nodes={
|
||||||
"width": ParamIntInvocation(id="width", a=512),
|
"width": IntegerInvocation(id="width", a=512),
|
||||||
"height": ParamIntInvocation(id="height", a=512),
|
"height": IntegerInvocation(id="height", a=512),
|
||||||
"seed": ParamIntInvocation(id="seed", a=-1),
|
"seed": IntegerInvocation(id="seed", a=-1),
|
||||||
"3": NoiseInvocation(id="3"),
|
"3": NoiseInvocation(id="3"),
|
||||||
"4": CompelInvocation(id="4"),
|
"4": CompelInvocation(id="4"),
|
||||||
"5": CompelInvocation(id="5"),
|
"5": CompelInvocation(id="5"),
|
||||||
|
@ -3,16 +3,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import itertools
|
import itertools
|
||||||
import uuid
|
import uuid
|
||||||
from typing import (
|
from typing import Annotated, Any, Literal, Optional, Union, get_args, get_origin, get_type_hints
|
||||||
Annotated,
|
|
||||||
Any,
|
|
||||||
Literal,
|
|
||||||
Optional,
|
|
||||||
Union,
|
|
||||||
get_args,
|
|
||||||
get_origin,
|
|
||||||
get_type_hints,
|
|
||||||
)
|
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from pydantic import BaseModel, root_validator, validator
|
from pydantic import BaseModel, root_validator, validator
|
||||||
@ -22,7 +13,11 @@ from ..invocations import *
|
|||||||
from ..invocations.baseinvocation import (
|
from ..invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
|
Input,
|
||||||
|
InputField,
|
||||||
InvocationContext,
|
InvocationContext,
|
||||||
|
OutputField,
|
||||||
|
UIType,
|
||||||
)
|
)
|
||||||
|
|
||||||
# in 3.10 this would be "from types import NoneType"
|
# in 3.10 this would be "from types import NoneType"
|
||||||
@ -183,15 +178,9 @@ class IterateInvocationOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
type: Literal["iterate_output"] = "iterate_output"
|
type: Literal["iterate_output"] = "iterate_output"
|
||||||
|
|
||||||
item: Any = Field(description="The item being iterated over")
|
item: Any = OutputField(
|
||||||
|
description="The item being iterated over", title="Collection Item", ui_type=UIType.CollectionItem
|
||||||
class Config:
|
)
|
||||||
schema_extra = {
|
|
||||||
"required": [
|
|
||||||
"type",
|
|
||||||
"item",
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Fill this out and move to invocations
|
# TODO: Fill this out and move to invocations
|
||||||
@ -200,8 +189,10 @@ class IterateInvocation(BaseInvocation):
|
|||||||
|
|
||||||
type: Literal["iterate"] = "iterate"
|
type: Literal["iterate"] = "iterate"
|
||||||
|
|
||||||
collection: list[Any] = Field(description="The list of items to iterate over", default_factory=list)
|
collection: list[Any] = InputField(
|
||||||
index: int = Field(description="The index, will be provided on executed iterators", default=0)
|
description="The list of items to iterate over", default_factory=list, ui_type=UIType.Collection
|
||||||
|
)
|
||||||
|
index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
|
def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
|
||||||
"""Produces the outputs as values"""
|
"""Produces the outputs as values"""
|
||||||
@ -211,15 +202,9 @@ class IterateInvocation(BaseInvocation):
|
|||||||
class CollectInvocationOutput(BaseInvocationOutput):
|
class CollectInvocationOutput(BaseInvocationOutput):
|
||||||
type: Literal["collect_output"] = "collect_output"
|
type: Literal["collect_output"] = "collect_output"
|
||||||
|
|
||||||
collection: list[Any] = Field(description="The collection of input items")
|
collection: list[Any] = OutputField(
|
||||||
|
description="The collection of input items", title="Collection", ui_type=UIType.Collection
|
||||||
class Config:
|
)
|
||||||
schema_extra = {
|
|
||||||
"required": [
|
|
||||||
"type",
|
|
||||||
"collection",
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class CollectInvocation(BaseInvocation):
|
class CollectInvocation(BaseInvocation):
|
||||||
@ -227,13 +212,14 @@ class CollectInvocation(BaseInvocation):
|
|||||||
|
|
||||||
type: Literal["collect"] = "collect"
|
type: Literal["collect"] = "collect"
|
||||||
|
|
||||||
item: Any = Field(
|
item: Any = InputField(
|
||||||
description="The item to collect (all inputs must be of the same type)",
|
description="The item to collect (all inputs must be of the same type)",
|
||||||
default=None,
|
ui_type=UIType.CollectionItem,
|
||||||
|
title="Collection Item",
|
||||||
|
input=Input.Connection,
|
||||||
)
|
)
|
||||||
collection: list[Any] = Field(
|
collection: list[Any] = InputField(
|
||||||
description="The collection, will be provided on execution",
|
description="The collection, will be provided on execution", default_factory=list, ui_hidden=True
|
||||||
default_factory=list,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
|
def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
|
||||||
|
@ -67,6 +67,7 @@ IMAGE_DTO_COLS = ", ".join(
|
|||||||
"created_at",
|
"created_at",
|
||||||
"updated_at",
|
"updated_at",
|
||||||
"deleted_at",
|
"deleted_at",
|
||||||
|
"starred",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -139,6 +140,7 @@ class ImageRecordStorageBase(ABC):
|
|||||||
node_id: Optional[str],
|
node_id: Optional[str],
|
||||||
metadata: Optional[dict],
|
metadata: Optional[dict],
|
||||||
is_intermediate: bool = False,
|
is_intermediate: bool = False,
|
||||||
|
starred: bool = False,
|
||||||
) -> datetime:
|
) -> datetime:
|
||||||
"""Saves an image record."""
|
"""Saves an image record."""
|
||||||
pass
|
pass
|
||||||
@ -200,6 +202,16 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._cursor.execute("PRAGMA table_info(images)")
|
||||||
|
columns = [column[1] for column in self._cursor.fetchall()]
|
||||||
|
|
||||||
|
if "starred" not in columns:
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
ALTER TABLE images ADD COLUMN starred BOOLEAN DEFAULT FALSE;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
# Create the `images` table indices.
|
# Create the `images` table indices.
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
@ -222,6 +234,12 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_images_starred ON images(starred);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
# Add trigger for `updated_at`.
|
# Add trigger for `updated_at`.
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
@ -321,6 +339,17 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
(changes.is_intermediate, image_name),
|
(changes.is_intermediate, image_name),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Change the image's `starred`` state
|
||||||
|
if changes.starred is not None:
|
||||||
|
self._cursor.execute(
|
||||||
|
f"""--sql
|
||||||
|
UPDATE images
|
||||||
|
SET starred = ?
|
||||||
|
WHERE image_name = ?;
|
||||||
|
""",
|
||||||
|
(changes.starred, image_name),
|
||||||
|
)
|
||||||
|
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
self._conn.rollback()
|
self._conn.rollback()
|
||||||
@ -397,7 +426,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
query_params.append(board_id)
|
query_params.append(board_id)
|
||||||
|
|
||||||
query_pagination = """--sql
|
query_pagination = """--sql
|
||||||
ORDER BY images.created_at DESC LIMIT ? OFFSET ?
|
ORDER BY images.starred DESC, images.created_at DESC LIMIT ? OFFSET ?
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Final images query with pagination
|
# Final images query with pagination
|
||||||
@ -500,6 +529,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
node_id: Optional[str],
|
node_id: Optional[str],
|
||||||
metadata: Optional[dict],
|
metadata: Optional[dict],
|
||||||
is_intermediate: bool = False,
|
is_intermediate: bool = False,
|
||||||
|
starred: bool = False,
|
||||||
) -> datetime:
|
) -> datetime:
|
||||||
try:
|
try:
|
||||||
metadata_json = None if metadata is None else json.dumps(metadata)
|
metadata_json = None if metadata is None else json.dumps(metadata)
|
||||||
@ -515,9 +545,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
node_id,
|
node_id,
|
||||||
session_id,
|
session_id,
|
||||||
metadata,
|
metadata,
|
||||||
is_intermediate
|
is_intermediate,
|
||||||
|
starred
|
||||||
)
|
)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
image_name,
|
image_name,
|
||||||
@ -529,6 +560,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
session_id,
|
session_id,
|
||||||
metadata_json,
|
metadata_json,
|
||||||
is_intermediate,
|
is_intermediate,
|
||||||
|
starred,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
|
@ -39,6 +39,8 @@ class ImageRecord(BaseModelExcludeNull):
|
|||||||
description="The node ID that generated this image, if it is a generated image.",
|
description="The node ID that generated this image, if it is a generated image.",
|
||||||
)
|
)
|
||||||
"""The node ID that generated this image, if it is a generated image."""
|
"""The node ID that generated this image, if it is a generated image."""
|
||||||
|
starred: bool = Field(description="Whether this image is starred.")
|
||||||
|
"""Whether this image is starred."""
|
||||||
|
|
||||||
|
|
||||||
class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
|
class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
|
||||||
@ -48,6 +50,7 @@ class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
|
|||||||
- `image_category`: change the category of an image
|
- `image_category`: change the category of an image
|
||||||
- `session_id`: change the session associated with an image
|
- `session_id`: change the session associated with an image
|
||||||
- `is_intermediate`: change the image's `is_intermediate` flag
|
- `is_intermediate`: change the image's `is_intermediate` flag
|
||||||
|
- `starred`: change whether the image is starred
|
||||||
"""
|
"""
|
||||||
|
|
||||||
image_category: Optional[ImageCategory] = Field(description="The image's new category.")
|
image_category: Optional[ImageCategory] = Field(description="The image's new category.")
|
||||||
@ -59,6 +62,8 @@ class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
|
|||||||
"""The image's new session ID."""
|
"""The image's new session ID."""
|
||||||
is_intermediate: Optional[StrictBool] = Field(default=None, description="The image's new `is_intermediate` flag.")
|
is_intermediate: Optional[StrictBool] = Field(default=None, description="The image's new `is_intermediate` flag.")
|
||||||
"""The image's new `is_intermediate` flag."""
|
"""The image's new `is_intermediate` flag."""
|
||||||
|
starred: Optional[StrictBool] = Field(default=None, description="The image's new `starred` state")
|
||||||
|
"""The image's new `starred` state."""
|
||||||
|
|
||||||
|
|
||||||
class ImageUrlsDTO(BaseModelExcludeNull):
|
class ImageUrlsDTO(BaseModelExcludeNull):
|
||||||
@ -113,6 +118,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
|||||||
updated_at = image_dict.get("updated_at", get_iso_timestamp())
|
updated_at = image_dict.get("updated_at", get_iso_timestamp())
|
||||||
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
|
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
|
||||||
is_intermediate = image_dict.get("is_intermediate", False)
|
is_intermediate = image_dict.get("is_intermediate", False)
|
||||||
|
starred = image_dict.get("starred", False)
|
||||||
|
|
||||||
return ImageRecord(
|
return ImageRecord(
|
||||||
image_name=image_name,
|
image_name=image_name,
|
||||||
@ -126,4 +132,5 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
|||||||
updated_at=updated_at,
|
updated_at=updated_at,
|
||||||
deleted_at=deleted_at,
|
deleted_at=deleted_at,
|
||||||
is_intermediate=is_intermediate,
|
is_intermediate=is_intermediate,
|
||||||
|
starred=starred,
|
||||||
)
|
)
|
||||||
|
@ -87,7 +87,10 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
# Invoke
|
# Invoke
|
||||||
try:
|
try:
|
||||||
with statistics.collect_stats(invocation, graph_execution_state.id):
|
with statistics.collect_stats(invocation, graph_execution_state.id):
|
||||||
outputs = invocation.invoke(
|
# use the internal invoke_internal(), which wraps the node's invoke() method in
|
||||||
|
# this accomodates nodes which require a value, but get it only from a
|
||||||
|
# connection
|
||||||
|
outputs = invocation.invoke_internal(
|
||||||
InvocationContext(
|
InvocationContext(
|
||||||
services=self.__invoker.services,
|
services=self.__invoker.services,
|
||||||
graph_execution_state_id=graph_execution_state.id,
|
graph_execution_state_id=graph_execution_state.id,
|
||||||
|
@ -49,7 +49,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
|
|
||||||
def _parse_item(self, item: str) -> T:
|
def _parse_item(self, item: str) -> T:
|
||||||
item_type = get_args(self.__orig_class__)[0]
|
item_type = get_args(self.__orig_class__)[0]
|
||||||
return parse_raw_as(item_type, item)
|
parsed = parse_raw_as(item_type, item)
|
||||||
|
return parsed
|
||||||
|
|
||||||
def set(self, item: T):
|
def set(self, item: T):
|
||||||
try:
|
try:
|
||||||
|
@ -109,7 +109,7 @@ class ModelMerger(object):
|
|||||||
# pick up the first model's vae
|
# pick up the first model's vae
|
||||||
if mod == model_names[0]:
|
if mod == model_names[0]:
|
||||||
vae = info.get("vae")
|
vae = info.get("vae")
|
||||||
model_paths.extend([config.root_path / info["path"]])
|
model_paths.extend([(config.root_path / info["path"]).as_posix()])
|
||||||
|
|
||||||
merge_method = None if interp == "weighted_sum" else MergeInterpolationMethod(interp)
|
merge_method = None if interp == "weighted_sum" else MergeInterpolationMethod(interp)
|
||||||
logger.debug(f"interp = {interp}, merge_method={merge_method}")
|
logger.debug(f"interp = {interp}, merge_method={merge_method}")
|
||||||
@ -120,11 +120,11 @@ class ModelMerger(object):
|
|||||||
else config.models_path / base_model.value / ModelType.Main.value
|
else config.models_path / base_model.value / ModelType.Main.value
|
||||||
)
|
)
|
||||||
dump_path.mkdir(parents=True, exist_ok=True)
|
dump_path.mkdir(parents=True, exist_ok=True)
|
||||||
dump_path = dump_path / merged_model_name
|
dump_path = (dump_path / merged_model_name).as_posix()
|
||||||
|
|
||||||
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
|
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
|
||||||
attributes = dict(
|
attributes = dict(
|
||||||
path=str(dump_path),
|
path=dump_path,
|
||||||
description=f"Merge of models {', '.join(model_names)}",
|
description=f"Merge of models {', '.join(model_names)}",
|
||||||
model_format="diffusers",
|
model_format="diffusers",
|
||||||
variant=ModelVariantType.Normal.value,
|
variant=ModelVariantType.Normal.value,
|
||||||
|
@ -481,9 +481,19 @@ class ControlNetFolderProbe(FolderProbeBase):
|
|||||||
with open(config_file, "r") as file:
|
with open(config_file, "r") as file:
|
||||||
config = json.load(file)
|
config = json.load(file)
|
||||||
# no obvious way to distinguish between sd2-base and sd2-768
|
# no obvious way to distinguish between sd2-base and sd2-768
|
||||||
return (
|
dimension = config["cross_attention_dim"]
|
||||||
BaseModelType.StableDiffusion1 if config["cross_attention_dim"] == 768 else BaseModelType.StableDiffusion2
|
base_model = (
|
||||||
|
BaseModelType.StableDiffusion1
|
||||||
|
if dimension == 768
|
||||||
|
else BaseModelType.StableDiffusion2
|
||||||
|
if dimension == 1024
|
||||||
|
else BaseModelType.StableDiffusionXL
|
||||||
|
if dimension == 2048
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
|
if not base_model:
|
||||||
|
raise InvalidModelException(f"Unable to determine model base for {self.folder_path}")
|
||||||
|
return base_model
|
||||||
|
|
||||||
|
|
||||||
class LoRAFolderProbe(FolderProbeBase):
|
class LoRAFolderProbe(FolderProbeBase):
|
||||||
|
@ -1,4 +0,0 @@
|
|||||||
"""
|
|
||||||
Initialization file for the web backend.
|
|
||||||
"""
|
|
||||||
from .invoke_ai_web_server import InvokeAIWebServer
|
|
File diff suppressed because it is too large
Load Diff
@ -1,56 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import os
|
|
||||||
|
|
||||||
from ...args import PRECISION_CHOICES
|
|
||||||
|
|
||||||
|
|
||||||
def create_cmd_parser():
|
|
||||||
parser = argparse.ArgumentParser(description="InvokeAI web UI")
|
|
||||||
parser.add_argument(
|
|
||||||
"--host",
|
|
||||||
type=str,
|
|
||||||
help="The host to serve on",
|
|
||||||
default="localhost",
|
|
||||||
)
|
|
||||||
parser.add_argument("--port", type=int, help="The port to serve on", default=9090)
|
|
||||||
parser.add_argument(
|
|
||||||
"--cors",
|
|
||||||
nargs="*",
|
|
||||||
type=str,
|
|
||||||
help="Additional allowed origins, comma-separated",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--embedding_path",
|
|
||||||
type=str,
|
|
||||||
help="Path to a pre-trained embedding manager checkpoint - can only be set on command line",
|
|
||||||
)
|
|
||||||
# TODO: Can't get flask to serve images from any dir (saving to the dir does work when specified)
|
|
||||||
# parser.add_argument(
|
|
||||||
# "--output_dir",
|
|
||||||
# default="outputs/",
|
|
||||||
# type=str,
|
|
||||||
# help="Directory for output images",
|
|
||||||
# )
|
|
||||||
parser.add_argument(
|
|
||||||
"-v",
|
|
||||||
"--verbose",
|
|
||||||
action="store_true",
|
|
||||||
help="Enables verbose logging",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--precision",
|
|
||||||
dest="precision",
|
|
||||||
type=str,
|
|
||||||
choices=PRECISION_CHOICES,
|
|
||||||
metavar="PRECISION",
|
|
||||||
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
|
|
||||||
default="auto",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--free_gpu_mem",
|
|
||||||
dest="free_gpu_mem",
|
|
||||||
action="store_true",
|
|
||||||
help="Force free gpu memory before final decoding",
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser
|
|
@ -1,113 +0,0 @@
|
|||||||
from typing import Literal, Union
|
|
||||||
|
|
||||||
from PIL import Image, ImageChops
|
|
||||||
from PIL.Image import Image as ImageType
|
|
||||||
|
|
||||||
|
|
||||||
# https://stackoverflow.com/questions/43864101/python-pil-check-if-image-is-transparent
|
|
||||||
def check_for_any_transparency(img: Union[ImageType, str]) -> bool:
|
|
||||||
if type(img) is str:
|
|
||||||
img = Image.open(str)
|
|
||||||
|
|
||||||
if img.info.get("transparency", None) is not None:
|
|
||||||
return True
|
|
||||||
if img.mode == "P":
|
|
||||||
transparent = img.info.get("transparency", -1)
|
|
||||||
for _, index in img.getcolors():
|
|
||||||
if index == transparent:
|
|
||||||
return True
|
|
||||||
elif img.mode == "RGBA":
|
|
||||||
extrema = img.getextrema()
|
|
||||||
if extrema[3][0] < 255:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def get_canvas_generation_mode(
|
|
||||||
init_img: Union[ImageType, str], init_mask: Union[ImageType, str]
|
|
||||||
) -> Literal["txt2img", "outpainting", "inpainting", "img2img",]:
|
|
||||||
if type(init_img) is str:
|
|
||||||
init_img = Image.open(init_img)
|
|
||||||
|
|
||||||
if type(init_mask) is str:
|
|
||||||
init_mask = Image.open(init_mask)
|
|
||||||
|
|
||||||
init_img = init_img.convert("RGBA")
|
|
||||||
|
|
||||||
# Get alpha from init_img
|
|
||||||
init_img_alpha = init_img.split()[-1]
|
|
||||||
init_img_alpha_mask = init_img_alpha.convert("L")
|
|
||||||
init_img_has_transparency = check_for_any_transparency(init_img)
|
|
||||||
|
|
||||||
if init_img_has_transparency:
|
|
||||||
init_img_is_fully_transparent = True if init_img_alpha_mask.getbbox() is None else False
|
|
||||||
|
|
||||||
"""
|
|
||||||
Mask images are white in areas where no change should be made, black where changes
|
|
||||||
should be made.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Fit the mask to init_img's size and convert it to greyscale
|
|
||||||
init_mask = init_mask.resize(init_img.size).convert("L")
|
|
||||||
|
|
||||||
"""
|
|
||||||
PIL.Image.getbbox() returns the bounding box of non-zero areas of the image, so we first
|
|
||||||
invert the mask image so that masked areas are white and other areas black == zero.
|
|
||||||
getbbox() now tells us if the are any masked areas.
|
|
||||||
"""
|
|
||||||
init_mask_bbox = ImageChops.invert(init_mask).getbbox()
|
|
||||||
init_mask_exists = False if init_mask_bbox is None else True
|
|
||||||
|
|
||||||
if init_img_has_transparency:
|
|
||||||
if init_img_is_fully_transparent:
|
|
||||||
return "txt2img"
|
|
||||||
else:
|
|
||||||
return "outpainting"
|
|
||||||
else:
|
|
||||||
if init_mask_exists:
|
|
||||||
return "inpainting"
|
|
||||||
else:
|
|
||||||
return "img2img"
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
# Testing
|
|
||||||
init_img_opaque = "test_images/init-img_opaque.png"
|
|
||||||
init_img_partial_transparency = "test_images/init-img_partial_transparency.png"
|
|
||||||
init_img_full_transparency = "test_images/init-img_full_transparency.png"
|
|
||||||
init_mask_no_mask = "test_images/init-mask_no_mask.png"
|
|
||||||
init_mask_has_mask = "test_images/init-mask_has_mask.png"
|
|
||||||
|
|
||||||
print(
|
|
||||||
"OPAQUE IMAGE, NO MASK, expect img2img, got ",
|
|
||||||
get_canvas_generation_mode(init_img_opaque, init_mask_no_mask),
|
|
||||||
)
|
|
||||||
|
|
||||||
print(
|
|
||||||
"IMAGE WITH TRANSPARENCY, NO MASK, expect outpainting, got ",
|
|
||||||
get_canvas_generation_mode(init_img_partial_transparency, init_mask_no_mask),
|
|
||||||
)
|
|
||||||
|
|
||||||
print(
|
|
||||||
"FULLY TRANSPARENT IMAGE NO MASK, expect txt2img, got ",
|
|
||||||
get_canvas_generation_mode(init_img_full_transparency, init_mask_no_mask),
|
|
||||||
)
|
|
||||||
|
|
||||||
print(
|
|
||||||
"OPAQUE IMAGE, WITH MASK, expect inpainting, got ",
|
|
||||||
get_canvas_generation_mode(init_img_opaque, init_mask_has_mask),
|
|
||||||
)
|
|
||||||
|
|
||||||
print(
|
|
||||||
"IMAGE WITH TRANSPARENCY, WITH MASK, expect outpainting, got ",
|
|
||||||
get_canvas_generation_mode(init_img_partial_transparency, init_mask_has_mask),
|
|
||||||
)
|
|
||||||
|
|
||||||
print(
|
|
||||||
"FULLY TRANSPARENT IMAGE WITH MASK, expect txt2img, got ",
|
|
||||||
get_canvas_generation_mode(init_img_full_transparency, init_mask_has_mask),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
@ -1,82 +0,0 @@
|
|||||||
import argparse
|
|
||||||
|
|
||||||
from .parse_seed_weights import parse_seed_weights
|
|
||||||
|
|
||||||
SAMPLER_CHOICES = [
|
|
||||||
"ddim",
|
|
||||||
"ddpm",
|
|
||||||
"deis",
|
|
||||||
"lms",
|
|
||||||
"lms_k",
|
|
||||||
"pndm",
|
|
||||||
"heun",
|
|
||||||
"heun_k",
|
|
||||||
"euler",
|
|
||||||
"euler_k",
|
|
||||||
"euler_a",
|
|
||||||
"kdpm_2",
|
|
||||||
"kdpm_2_a",
|
|
||||||
"dpmpp_2s",
|
|
||||||
"dpmpp_2s_k",
|
|
||||||
"dpmpp_2m",
|
|
||||||
"dpmpp_2m_k",
|
|
||||||
"dpmpp_2m_sde",
|
|
||||||
"dpmpp_2m_sde_k",
|
|
||||||
"dpmpp_sde",
|
|
||||||
"dpmpp_sde_k",
|
|
||||||
"unipc",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def parameters_to_command(params):
|
|
||||||
"""
|
|
||||||
Converts dict of parameters into a `invoke.py` REPL command.
|
|
||||||
"""
|
|
||||||
|
|
||||||
switches = list()
|
|
||||||
|
|
||||||
if "prompt" in params:
|
|
||||||
switches.append(f'"{params["prompt"]}"')
|
|
||||||
if "steps" in params:
|
|
||||||
switches.append(f'-s {params["steps"]}')
|
|
||||||
if "seed" in params:
|
|
||||||
switches.append(f'-S {params["seed"]}')
|
|
||||||
if "width" in params:
|
|
||||||
switches.append(f'-W {params["width"]}')
|
|
||||||
if "height" in params:
|
|
||||||
switches.append(f'-H {params["height"]}')
|
|
||||||
if "cfg_scale" in params:
|
|
||||||
switches.append(f'-C {params["cfg_scale"]}')
|
|
||||||
if "sampler_name" in params:
|
|
||||||
switches.append(f'-A {params["sampler_name"]}')
|
|
||||||
if "seamless" in params and params["seamless"] == True:
|
|
||||||
switches.append(f"--seamless")
|
|
||||||
if "hires_fix" in params and params["hires_fix"] == True:
|
|
||||||
switches.append(f"--hires")
|
|
||||||
if "init_img" in params and len(params["init_img"]) > 0:
|
|
||||||
switches.append(f'-I {params["init_img"]}')
|
|
||||||
if "init_mask" in params and len(params["init_mask"]) > 0:
|
|
||||||
switches.append(f'-M {params["init_mask"]}')
|
|
||||||
if "init_color" in params and len(params["init_color"]) > 0:
|
|
||||||
switches.append(f'--init_color {params["init_color"]}')
|
|
||||||
if "strength" in params and "init_img" in params:
|
|
||||||
switches.append(f'-f {params["strength"]}')
|
|
||||||
if "fit" in params and params["fit"] == True:
|
|
||||||
switches.append(f"--fit")
|
|
||||||
if "facetool" in params:
|
|
||||||
switches.append(f'-ft {params["facetool"]}')
|
|
||||||
if "facetool_strength" in params and params["facetool_strength"]:
|
|
||||||
switches.append(f'-G {params["facetool_strength"]}')
|
|
||||||
elif "gfpgan_strength" in params and params["gfpgan_strength"]:
|
|
||||||
switches.append(f'-G {params["gfpgan_strength"]}')
|
|
||||||
if "codeformer_fidelity" in params:
|
|
||||||
switches.append(f'-cf {params["codeformer_fidelity"]}')
|
|
||||||
if "upscale" in params and params["upscale"]:
|
|
||||||
switches.append(f'-U {params["upscale"][0]} {params["upscale"][1]}')
|
|
||||||
if "variation_amount" in params and params["variation_amount"] > 0:
|
|
||||||
switches.append(f'-v {params["variation_amount"]}')
|
|
||||||
if "with_variations" in params:
|
|
||||||
seed_weight_pairs = ",".join(f"{seed}:{weight}" for seed, weight in params["with_variations"])
|
|
||||||
switches.append(f"-V {seed_weight_pairs}")
|
|
||||||
|
|
||||||
return " ".join(switches)
|
|
@ -1,47 +0,0 @@
|
|||||||
def parse_seed_weights(seed_weights):
|
|
||||||
"""
|
|
||||||
Accepts seed weights as string in "12345:0.1,23456:0.2,3456:0.3" format
|
|
||||||
Validates them
|
|
||||||
If valid: returns as [[12345, 0.1], [23456, 0.2], [3456, 0.3]]
|
|
||||||
If invalid: returns False
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Must be a string
|
|
||||||
if not isinstance(seed_weights, str):
|
|
||||||
return False
|
|
||||||
# String must not be empty
|
|
||||||
if len(seed_weights) == 0:
|
|
||||||
return False
|
|
||||||
|
|
||||||
pairs = []
|
|
||||||
|
|
||||||
for pair in seed_weights.split(","):
|
|
||||||
split_values = pair.split(":")
|
|
||||||
|
|
||||||
# Seed and weight are required
|
|
||||||
if len(split_values) != 2:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if len(split_values[0]) == 0 or len(split_values[1]) == 1:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Try casting the seed to int and weight to float
|
|
||||||
try:
|
|
||||||
seed = int(split_values[0])
|
|
||||||
weight = float(split_values[1])
|
|
||||||
except ValueError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Seed must be 0 or above
|
|
||||||
if not seed >= 0:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Weight must be between 0 and 1
|
|
||||||
if not (weight >= 0 and weight <= 1):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# This pair is valid
|
|
||||||
pairs.append([seed, weight])
|
|
||||||
|
|
||||||
# All pairs are valid
|
|
||||||
return pairs
|
|
Binary file not shown.
Before Width: | Height: | Size: 2.7 KiB |
Binary file not shown.
Before Width: | Height: | Size: 292 KiB |
Binary file not shown.
Before Width: | Height: | Size: 164 KiB |
Binary file not shown.
Before Width: | Height: | Size: 9.5 KiB |
Binary file not shown.
Before Width: | Height: | Size: 3.4 KiB |
169
invokeai/frontend/web/dist/assets/App-0a099278.js
vendored
169
invokeai/frontend/web/dist/assets/App-0a099278.js
vendored
File diff suppressed because one or more lines are too long
169
invokeai/frontend/web/dist/assets/App-7d912410.js
vendored
Normal file
169
invokeai/frontend/web/dist/assets/App-7d912410.js
vendored
Normal file
File diff suppressed because one or more lines are too long
@ -1,4 +1,4 @@
|
|||||||
import{B as m,g7 as Je,A as y,a5 as Ka,g8 as Xa,af as va,aj as d,g9 as b,ga as t,gb as Ya,gc as h,gd as ua,ge as Ja,gf as Qa,aL as Za,gg as et,ad as rt,gh as at}from"./index-deaa1f26.js";import{s as fa,n as o,t as tt,o as ha,p as ot,q as ma,v as ga,w as ya,x as it,y as Sa,z as pa,A as xr,B as nt,D as lt,E as st,F as xa,G as $a,H as ka,J as dt,K as _a,L as ct,M as bt,N as vt,O as ut,Q as wa,R as ft,S as ht,T as mt,U as gt,V as yt,W as St,e as pt,X as xt}from"./menu-b4489359.js";var za=String.raw,Ca=za`
|
import{B as m,g7 as Je,A as y,a5 as Ka,g8 as Xa,af as va,aj as d,g9 as b,ga as t,gb as Ya,gc as h,gd as ua,ge as Ja,gf as Qa,aL as Za,gg as et,ad as rt,gh as at}from"./index-2c171c8f.js";import{s as fa,n as o,t as tt,o as ha,p as ot,q as ma,v as ga,w as ya,x as it,y as Sa,z as pa,A as xr,B as nt,D as lt,E as st,F as xa,G as $a,H as ka,J as dt,K as _a,L as ct,M as bt,N as vt,O as ut,Q as wa,R as ft,S as ht,T as mt,U as gt,V as yt,W as St,e as pt,X as xt}from"./menu-971c0572.js";var za=String.raw,Ca=za`
|
||||||
:root,
|
:root,
|
||||||
:host {
|
:host {
|
||||||
--chakra-vh: 100vh;
|
--chakra-vh: 100vh;
|
151
invokeai/frontend/web/dist/assets/index-2c171c8f.js
vendored
Normal file
151
invokeai/frontend/web/dist/assets/index-2c171c8f.js
vendored
Normal file
File diff suppressed because one or more lines are too long
151
invokeai/frontend/web/dist/assets/index-deaa1f26.js
vendored
151
invokeai/frontend/web/dist/assets/index-deaa1f26.js
vendored
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
2
invokeai/frontend/web/dist/index.html
vendored
2
invokeai/frontend/web/dist/index.html
vendored
@ -12,7 +12,7 @@
|
|||||||
margin: 0;
|
margin: 0;
|
||||||
}
|
}
|
||||||
</style>
|
</style>
|
||||||
<script type="module" crossorigin src="./assets/index-deaa1f26.js"></script>
|
<script type="module" crossorigin src="./assets/index-2c171c8f.js"></script>
|
||||||
</head>
|
</head>
|
||||||
|
|
||||||
<body dir="ltr">
|
<body dir="ltr">
|
||||||
|
3
invokeai/frontend/web/dist/locales/en.json
vendored
3
invokeai/frontend/web/dist/locales/en.json
vendored
@ -503,6 +503,9 @@
|
|||||||
"hiresStrength": "High Res Strength",
|
"hiresStrength": "High Res Strength",
|
||||||
"imageFit": "Fit Initial Image To Output Size",
|
"imageFit": "Fit Initial Image To Output Size",
|
||||||
"codeformerFidelity": "Fidelity",
|
"codeformerFidelity": "Fidelity",
|
||||||
|
"maskAdjustmentsHeader": "Mask Adjustments",
|
||||||
|
"maskBlur": "Mask Blur",
|
||||||
|
"maskBlurMethod": "Mask Blur Method",
|
||||||
"seamSize": "Seam Size",
|
"seamSize": "Seam Size",
|
||||||
"seamBlur": "Seam Blur",
|
"seamBlur": "Seam Blur",
|
||||||
"seamStrength": "Seam Strength",
|
"seamStrength": "Seam Strength",
|
||||||
|
@ -61,6 +61,7 @@
|
|||||||
"@dagrejs/graphlib": "^2.1.13",
|
"@dagrejs/graphlib": "^2.1.13",
|
||||||
"@dnd-kit/core": "^6.0.8",
|
"@dnd-kit/core": "^6.0.8",
|
||||||
"@dnd-kit/modifiers": "^6.0.1",
|
"@dnd-kit/modifiers": "^6.0.1",
|
||||||
|
"@dnd-kit/utilities": "^3.2.1",
|
||||||
"@emotion/react": "^11.11.1",
|
"@emotion/react": "^11.11.1",
|
||||||
"@emotion/styled": "^11.11.0",
|
"@emotion/styled": "^11.11.0",
|
||||||
"@floating-ui/react-dom": "^2.0.1",
|
"@floating-ui/react-dom": "^2.0.1",
|
||||||
|
34
invokeai/frontend/web/scripts/colors.js
Normal file
34
invokeai/frontend/web/scripts/colors.js
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
export const COLORS = {
|
||||||
|
reset: '\x1b[0m',
|
||||||
|
bright: '\x1b[1m',
|
||||||
|
dim: '\x1b[2m',
|
||||||
|
underscore: '\x1b[4m',
|
||||||
|
blink: '\x1b[5m',
|
||||||
|
reverse: '\x1b[7m',
|
||||||
|
hidden: '\x1b[8m',
|
||||||
|
|
||||||
|
fg: {
|
||||||
|
black: '\x1b[30m',
|
||||||
|
red: '\x1b[31m',
|
||||||
|
green: '\x1b[32m',
|
||||||
|
yellow: '\x1b[33m',
|
||||||
|
blue: '\x1b[34m',
|
||||||
|
magenta: '\x1b[35m',
|
||||||
|
cyan: '\x1b[36m',
|
||||||
|
white: '\x1b[37m',
|
||||||
|
gray: '\x1b[90m',
|
||||||
|
crimson: '\x1b[38m',
|
||||||
|
},
|
||||||
|
bg: {
|
||||||
|
black: '\x1b[40m',
|
||||||
|
red: '\x1b[41m',
|
||||||
|
green: '\x1b[42m',
|
||||||
|
yellow: '\x1b[43m',
|
||||||
|
blue: '\x1b[44m',
|
||||||
|
magenta: '\x1b[45m',
|
||||||
|
cyan: '\x1b[46m',
|
||||||
|
white: '\x1b[47m',
|
||||||
|
gray: '\x1b[100m',
|
||||||
|
crimson: '\x1b[48m',
|
||||||
|
},
|
||||||
|
};
|
@ -1,23 +1,83 @@
|
|||||||
import fs from 'node:fs';
|
import fs from 'node:fs';
|
||||||
import openapiTS from 'openapi-typescript';
|
import openapiTS from 'openapi-typescript';
|
||||||
|
import { COLORS } from './colors.js';
|
||||||
|
|
||||||
const OPENAPI_URL = 'http://127.0.0.1:9090/openapi.json';
|
const OPENAPI_URL = 'http://127.0.0.1:9090/openapi.json';
|
||||||
const OUTPUT_FILE = 'src/services/api/schema.d.ts';
|
const OUTPUT_FILE = 'src/services/api/schema.d.ts';
|
||||||
|
|
||||||
async function main() {
|
async function main() {
|
||||||
process.stdout.write(
|
process.stdout.write(
|
||||||
`Generating types "${OPENAPI_URL}" --> "${OUTPUT_FILE}"...`
|
`Generating types "${OPENAPI_URL}" --> "${OUTPUT_FILE}"...\n\n`
|
||||||
);
|
);
|
||||||
const types = await openapiTS(OPENAPI_URL, {
|
const types = await openapiTS(OPENAPI_URL, {
|
||||||
exportType: true,
|
exportType: true,
|
||||||
transform: (schemaObject) => {
|
transform: (schemaObject, metadata) => {
|
||||||
if ('format' in schemaObject && schemaObject.format === 'binary') {
|
if ('format' in schemaObject && schemaObject.format === 'binary') {
|
||||||
return schemaObject.nullable ? 'Blob | null' : 'Blob';
|
return schemaObject.nullable ? 'Blob | null' : 'Blob';
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Because invocations may have required fields that accept connection input, the generated
|
||||||
|
* types may be incorrect.
|
||||||
|
*
|
||||||
|
* For example, the ImageResizeInvocation has a required `image` field, but because it accepts
|
||||||
|
* connection input, it should be optional on instantiation of the field.
|
||||||
|
*
|
||||||
|
* To handle this, the schema exposes an `input` property that can be used to determine if the
|
||||||
|
* field accepts connection input. If it does, we can make the field optional.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Check if we are generating types for an invocation
|
||||||
|
const isInvocationPath = metadata.path.match(
|
||||||
|
/^#\/components\/schemas\/\w*Invocation$/
|
||||||
|
);
|
||||||
|
|
||||||
|
const hasInvocationProperties =
|
||||||
|
schemaObject.properties &&
|
||||||
|
['id', 'is_intermediate', 'type'].every(
|
||||||
|
(prop) => prop in schemaObject.properties
|
||||||
|
);
|
||||||
|
|
||||||
|
if (isInvocationPath && hasInvocationProperties) {
|
||||||
|
// We only want to make fields optional if they are required
|
||||||
|
if (!Array.isArray(schemaObject?.required)) {
|
||||||
|
schemaObject.required = ['id', 'type'];
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
schemaObject.required.forEach((prop) => {
|
||||||
|
const acceptsConnection = ['any', 'connection'].includes(
|
||||||
|
schemaObject.properties?.[prop]?.['input']
|
||||||
|
);
|
||||||
|
|
||||||
|
if (acceptsConnection) {
|
||||||
|
// remove this prop from the required array
|
||||||
|
const invocationName = metadata.path.split('/').pop();
|
||||||
|
console.log(
|
||||||
|
`Making connectable field optional: ${COLORS.fg.green}${invocationName}.${COLORS.fg.cyan}${prop}${COLORS.reset}`
|
||||||
|
);
|
||||||
|
schemaObject.required = schemaObject.required.filter(
|
||||||
|
(r) => r !== prop
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
schemaObject.required = [
|
||||||
|
...new Set(schemaObject.required.concat(['id', 'type'])),
|
||||||
|
];
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// if (
|
||||||
|
// 'input' in schemaObject &&
|
||||||
|
// (schemaObject.input === 'any' || schemaObject.input === 'connection')
|
||||||
|
// ) {
|
||||||
|
// schemaObject.required = false;
|
||||||
|
// }
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
fs.writeFileSync(OUTPUT_FILE, types);
|
fs.writeFileSync(OUTPUT_FILE, types);
|
||||||
process.stdout.write(` OK!\r\n`);
|
process.stdout.write(`\nOK!\r\n`);
|
||||||
}
|
}
|
||||||
|
|
||||||
main();
|
main();
|
||||||
|
@ -1,8 +1,12 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { RootState } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||||
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
|
import {
|
||||||
|
ctrlKeyPressed,
|
||||||
|
metaKeyPressed,
|
||||||
|
shiftKeyPressed,
|
||||||
|
} from 'features/ui/store/hotkeysSlice';
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
import {
|
import {
|
||||||
setActiveTab,
|
setActiveTab,
|
||||||
@ -16,11 +20,11 @@ import React, { memo } from 'react';
|
|||||||
import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
|
import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
|
||||||
|
|
||||||
const globalHotkeysSelector = createSelector(
|
const globalHotkeysSelector = createSelector(
|
||||||
[(state: RootState) => state.hotkeys, (state: RootState) => state.ui],
|
[stateSelector],
|
||||||
(hotkeys, ui) => {
|
({ hotkeys, ui }) => {
|
||||||
const { shift } = hotkeys;
|
const { shift, ctrl, meta } = hotkeys;
|
||||||
const { shouldPinParametersPanel, shouldPinGallery } = ui;
|
const { shouldPinParametersPanel, shouldPinGallery } = ui;
|
||||||
return { shift, shouldPinGallery, shouldPinParametersPanel };
|
return { shift, ctrl, meta, shouldPinGallery, shouldPinParametersPanel };
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
memoizeOptions: {
|
memoizeOptions: {
|
||||||
@ -37,9 +41,8 @@ const globalHotkeysSelector = createSelector(
|
|||||||
*/
|
*/
|
||||||
const GlobalHotkeys: React.FC = () => {
|
const GlobalHotkeys: React.FC = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { shift, shouldPinParametersPanel, shouldPinGallery } = useAppSelector(
|
const { shift, ctrl, meta, shouldPinParametersPanel, shouldPinGallery } =
|
||||||
globalHotkeysSelector
|
useAppSelector(globalHotkeysSelector);
|
||||||
);
|
|
||||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
@ -50,9 +53,19 @@ const GlobalHotkeys: React.FC = () => {
|
|||||||
} else {
|
} else {
|
||||||
shift && dispatch(shiftKeyPressed(false));
|
shift && dispatch(shiftKeyPressed(false));
|
||||||
}
|
}
|
||||||
|
if (isHotkeyPressed('ctrl')) {
|
||||||
|
!ctrl && dispatch(ctrlKeyPressed(true));
|
||||||
|
} else {
|
||||||
|
ctrl && dispatch(ctrlKeyPressed(false));
|
||||||
|
}
|
||||||
|
if (isHotkeyPressed('meta')) {
|
||||||
|
!meta && dispatch(metaKeyPressed(true));
|
||||||
|
} else {
|
||||||
|
meta && dispatch(metaKeyPressed(false));
|
||||||
|
}
|
||||||
},
|
},
|
||||||
{ keyup: true, keydown: true },
|
{ keyup: true, keydown: true },
|
||||||
[shift]
|
[shift, ctrl, meta]
|
||||||
);
|
);
|
||||||
|
|
||||||
useHotkeys('o', () => {
|
useHotkeys('o', () => {
|
||||||
|
@ -14,7 +14,7 @@ import { $authToken, $baseUrl, $projectId } from 'services/api/client';
|
|||||||
import { socketMiddleware } from 'services/events/middleware';
|
import { socketMiddleware } from 'services/events/middleware';
|
||||||
import Loading from '../../common/components/Loading/Loading';
|
import Loading from '../../common/components/Loading/Loading';
|
||||||
import '../../i18n';
|
import '../../i18n';
|
||||||
import ImageDndContext from './ImageDnd/ImageDndContext';
|
import AppDndContext from '../../features/dnd/components/AppDndContext';
|
||||||
|
|
||||||
const App = lazy(() => import('./App'));
|
const App = lazy(() => import('./App'));
|
||||||
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
|
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
|
||||||
@ -80,9 +80,9 @@ const InvokeAIUI = ({
|
|||||||
<Provider store={store}>
|
<Provider store={store}>
|
||||||
<React.Suspense fallback={<Loading />}>
|
<React.Suspense fallback={<Loading />}>
|
||||||
<ThemeLocaleProvider>
|
<ThemeLocaleProvider>
|
||||||
<ImageDndContext>
|
<AppDndContext>
|
||||||
<App config={config} headerComponent={headerComponent} />
|
<App config={config} headerComponent={headerComponent} />
|
||||||
</ImageDndContext>
|
</AppDndContext>
|
||||||
</ThemeLocaleProvider>
|
</ThemeLocaleProvider>
|
||||||
</React.Suspense>
|
</React.Suspense>
|
||||||
</Provider>
|
</Provider>
|
||||||
|
@ -19,7 +19,8 @@ type LoggerNamespace =
|
|||||||
| 'nodes'
|
| 'nodes'
|
||||||
| 'system'
|
| 'system'
|
||||||
| 'socketio'
|
| 'socketio'
|
||||||
| 'session';
|
| 'session'
|
||||||
|
| 'dnd';
|
||||||
|
|
||||||
export const logger = (namespace: LoggerNamespace) =>
|
export const logger = (namespace: LoggerNamespace) =>
|
||||||
$logger.get().child({ namespace });
|
$logger.get().child({ namespace });
|
||||||
|
@ -15,7 +15,7 @@ export const actionsDenylist = [
|
|||||||
'socket/socketGeneratorProgress',
|
'socket/socketGeneratorProgress',
|
||||||
'socket/appSocketGeneratorProgress',
|
'socket/appSocketGeneratorProgress',
|
||||||
// every time user presses shift
|
// every time user presses shift
|
||||||
'hotkeys/shiftKeyPressed',
|
// 'hotkeys/shiftKeyPressed',
|
||||||
// this happens after every state change
|
// this happens after every state change
|
||||||
'@@REMEMBER_PERSISTED',
|
'@@REMEMBER_PERSISTED',
|
||||||
];
|
];
|
||||||
|
@ -15,6 +15,7 @@ import { addDeleteBoardAndImagesFulfilledListener } from './listeners/boardAndIm
|
|||||||
import { addBoardIdSelectedListener } from './listeners/boardIdSelected';
|
import { addBoardIdSelectedListener } from './listeners/boardIdSelected';
|
||||||
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
|
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
|
||||||
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
|
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
|
||||||
|
import { addCanvasMaskSavedToGalleryListener } from './listeners/canvasMaskSavedToGallery';
|
||||||
import { addCanvasMergedListener } from './listeners/canvasMerged';
|
import { addCanvasMergedListener } from './listeners/canvasMerged';
|
||||||
import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery';
|
import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery';
|
||||||
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
|
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
|
||||||
@ -27,8 +28,8 @@ import {
|
|||||||
addImageDeletedFulfilledListener,
|
addImageDeletedFulfilledListener,
|
||||||
addImageDeletedPendingListener,
|
addImageDeletedPendingListener,
|
||||||
addImageDeletedRejectedListener,
|
addImageDeletedRejectedListener,
|
||||||
addRequestedSingleImageDeletionListener,
|
|
||||||
addRequestedMultipleImageDeletionListener,
|
addRequestedMultipleImageDeletionListener,
|
||||||
|
addRequestedSingleImageDeletionListener,
|
||||||
} from './listeners/imageDeleted';
|
} from './listeners/imageDeleted';
|
||||||
import { addImageDroppedListener } from './listeners/imageDropped';
|
import { addImageDroppedListener } from './listeners/imageDropped';
|
||||||
import {
|
import {
|
||||||
@ -79,6 +80,8 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
|
|||||||
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
|
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
|
||||||
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
|
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
|
||||||
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
|
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
|
||||||
|
import { addImagesStarredListener } from './listeners/imagesStarred';
|
||||||
|
import { addImagesUnstarredListener } from './listeners/imagesUnstarred';
|
||||||
|
|
||||||
export const listenerMiddleware = createListenerMiddleware();
|
export const listenerMiddleware = createListenerMiddleware();
|
||||||
|
|
||||||
@ -120,6 +123,10 @@ addImageDeletedRejectedListener();
|
|||||||
addDeleteBoardAndImagesFulfilledListener();
|
addDeleteBoardAndImagesFulfilledListener();
|
||||||
addImageToDeleteSelectedListener();
|
addImageToDeleteSelectedListener();
|
||||||
|
|
||||||
|
// Image starred
|
||||||
|
addImagesStarredListener();
|
||||||
|
addImagesUnstarredListener();
|
||||||
|
|
||||||
// User Invoked
|
// User Invoked
|
||||||
addUserInvokedCanvasListener();
|
addUserInvokedCanvasListener();
|
||||||
addUserInvokedNodesListener();
|
addUserInvokedNodesListener();
|
||||||
@ -129,6 +136,7 @@ addSessionReadyToInvokeListener();
|
|||||||
|
|
||||||
// Canvas actions
|
// Canvas actions
|
||||||
addCanvasSavedToGalleryListener();
|
addCanvasSavedToGalleryListener();
|
||||||
|
addCanvasMaskSavedToGalleryListener();
|
||||||
addCanvasDownloadedAsImageListener();
|
addCanvasDownloadedAsImageListener();
|
||||||
addCanvasCopiedToClipboardListener();
|
addCanvasCopiedToClipboardListener();
|
||||||
addCanvasMergedListener();
|
addCanvasMergedListener();
|
||||||
|
@ -0,0 +1,60 @@
|
|||||||
|
import { logger } from 'app/logging/logger';
|
||||||
|
import { canvasMaskSavedToGallery } from 'features/canvas/store/actions';
|
||||||
|
import { getCanvasData } from 'features/canvas/util/getCanvasData';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
|
||||||
|
export const addCanvasMaskSavedToGalleryListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: canvasMaskSavedToGallery,
|
||||||
|
effect: async (action, { dispatch, getState }) => {
|
||||||
|
const log = logger('canvas');
|
||||||
|
const state = getState();
|
||||||
|
|
||||||
|
const canvasBlobsAndImageData = await getCanvasData(
|
||||||
|
state.canvas.layerState,
|
||||||
|
state.canvas.boundingBoxCoordinates,
|
||||||
|
state.canvas.boundingBoxDimensions,
|
||||||
|
state.canvas.isMaskEnabled,
|
||||||
|
state.canvas.shouldPreserveMaskedArea
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!canvasBlobsAndImageData) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { maskBlob } = canvasBlobsAndImageData;
|
||||||
|
|
||||||
|
if (!maskBlob) {
|
||||||
|
log.error('Problem getting mask layer blob');
|
||||||
|
dispatch(
|
||||||
|
addToast({
|
||||||
|
title: 'Problem Saving Mask',
|
||||||
|
description: 'Unable to export mask',
|
||||||
|
status: 'error',
|
||||||
|
})
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { autoAddBoardId } = state.gallery;
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
imagesApi.endpoints.uploadImage.initiate({
|
||||||
|
file: new File([maskBlob], 'canvasMaskImage.png', {
|
||||||
|
type: 'image/png',
|
||||||
|
}),
|
||||||
|
image_category: 'mask',
|
||||||
|
is_intermediate: false,
|
||||||
|
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||||
|
crop_visible: true,
|
||||||
|
postUploadAction: {
|
||||||
|
type: 'TOAST',
|
||||||
|
toastOptions: { title: 'Mask Saved to Assets' },
|
||||||
|
},
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -1,16 +1,20 @@
|
|||||||
import { createAction } from '@reduxjs/toolkit';
|
import { createAction } from '@reduxjs/toolkit';
|
||||||
import {
|
|
||||||
TypesafeDraggableData,
|
|
||||||
TypesafeDroppableData,
|
|
||||||
} from 'app/components/ImageDnd/typesafeDnd';
|
|
||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||||
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
||||||
|
import {
|
||||||
|
TypesafeDraggableData,
|
||||||
|
TypesafeDroppableData,
|
||||||
|
} from 'features/dnd/types';
|
||||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
import {
|
||||||
|
fieldImageValueChanged,
|
||||||
|
workflowExposedFieldAdded,
|
||||||
|
} from 'features/nodes/store/nodesSlice';
|
||||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
import { startAppListening } from '../';
|
import { startAppListening } from '../';
|
||||||
|
import { parseify } from 'common/util/serialize';
|
||||||
|
|
||||||
export const dndDropped = createAction<{
|
export const dndDropped = createAction<{
|
||||||
overData: TypesafeDroppableData;
|
overData: TypesafeDroppableData;
|
||||||
@ -21,7 +25,7 @@ export const addImageDroppedListener = () => {
|
|||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: dndDropped,
|
actionCreator: dndDropped,
|
||||||
effect: async (action, { dispatch }) => {
|
effect: async (action, { dispatch }) => {
|
||||||
const log = logger('images');
|
const log = logger('dnd');
|
||||||
const { activeData, overData } = action.payload;
|
const { activeData, overData } = action.payload;
|
||||||
|
|
||||||
if (activeData.payloadType === 'IMAGE_DTO') {
|
if (activeData.payloadType === 'IMAGE_DTO') {
|
||||||
@ -31,10 +35,28 @@ export const addImageDroppedListener = () => {
|
|||||||
{ activeData, overData },
|
{ activeData, overData },
|
||||||
`Images (${activeData.payload.imageDTOs.length}) dropped`
|
`Images (${activeData.payload.imageDTOs.length}) dropped`
|
||||||
);
|
);
|
||||||
|
} else if (activeData.payloadType === 'NODE_FIELD') {
|
||||||
|
log.debug(
|
||||||
|
{ activeData: parseify(activeData), overData: parseify(overData) },
|
||||||
|
'Node field dropped'
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
log.debug({ activeData, overData }, `Unknown payload dropped`);
|
log.debug({ activeData, overData }, `Unknown payload dropped`);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
overData.actionType === 'ADD_FIELD_TO_LINEAR' &&
|
||||||
|
activeData.payloadType === 'NODE_FIELD'
|
||||||
|
) {
|
||||||
|
const { nodeId, field } = activeData.payload;
|
||||||
|
dispatch(
|
||||||
|
workflowExposedFieldAdded({
|
||||||
|
nodeId,
|
||||||
|
fieldName: field.name,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Image dropped on current image
|
* Image dropped on current image
|
||||||
*/
|
*/
|
||||||
@ -99,7 +121,7 @@ export const addImageDroppedListener = () => {
|
|||||||
) {
|
) {
|
||||||
const { fieldName, nodeId } = overData.context;
|
const { fieldName, nodeId } = overData.context;
|
||||||
dispatch(
|
dispatch(
|
||||||
fieldValueChanged({
|
fieldImageValueChanged({
|
||||||
nodeId,
|
nodeId,
|
||||||
fieldName,
|
fieldName,
|
||||||
value: activeData.payload.imageDTO,
|
value: activeData.payload.imageDTO,
|
||||||
|
@ -2,7 +2,7 @@ import { UseToastOptions } from '@chakra-ui/react';
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||||
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
||||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { omit } from 'lodash-es';
|
import { omit } from 'lodash-es';
|
||||||
@ -111,7 +111,9 @@ export const addImageUploadedFulfilledListener = () => {
|
|||||||
|
|
||||||
if (postUploadAction?.type === 'SET_NODES_IMAGE') {
|
if (postUploadAction?.type === 'SET_NODES_IMAGE') {
|
||||||
const { nodeId, fieldName } = postUploadAction;
|
const { nodeId, fieldName } = postUploadAction;
|
||||||
dispatch(fieldValueChanged({ nodeId, fieldName, value: imageDTO }));
|
dispatch(
|
||||||
|
fieldImageValueChanged({ nodeId, fieldName, value: imageDTO })
|
||||||
|
);
|
||||||
dispatch(
|
dispatch(
|
||||||
addToast({
|
addToast({
|
||||||
...DEFAULT_UPLOADED_TOAST,
|
...DEFAULT_UPLOADED_TOAST,
|
||||||
|
@ -0,0 +1,30 @@
|
|||||||
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
import { selectionChanged } from '../../../../../features/gallery/store/gallerySlice';
|
||||||
|
import { ImageDTO } from '../../../../../services/api/types';
|
||||||
|
|
||||||
|
export const addImagesStarredListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
matcher: imagesApi.endpoints.starImages.matchFulfilled,
|
||||||
|
effect: async (action, { dispatch, getState }) => {
|
||||||
|
const { updated_image_names: starredImages } = action.payload;
|
||||||
|
|
||||||
|
const state = getState();
|
||||||
|
|
||||||
|
const { selection } = state.gallery;
|
||||||
|
const updatedSelection: ImageDTO[] = [];
|
||||||
|
|
||||||
|
selection.forEach((selectedImageDTO) => {
|
||||||
|
if (starredImages.includes(selectedImageDTO.image_name)) {
|
||||||
|
updatedSelection.push({
|
||||||
|
...selectedImageDTO,
|
||||||
|
starred: true,
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
updatedSelection.push(selectedImageDTO);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
dispatch(selectionChanged(updatedSelection));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,30 @@
|
|||||||
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
import { selectionChanged } from '../../../../../features/gallery/store/gallerySlice';
|
||||||
|
import { ImageDTO } from '../../../../../services/api/types';
|
||||||
|
|
||||||
|
export const addImagesUnstarredListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
matcher: imagesApi.endpoints.unstarImages.matchFulfilled,
|
||||||
|
effect: async (action, { dispatch, getState }) => {
|
||||||
|
const { updated_image_names: unstarredImages } = action.payload;
|
||||||
|
|
||||||
|
const state = getState();
|
||||||
|
|
||||||
|
const { selection } = state.gallery;
|
||||||
|
const updatedSelection: ImageDTO[] = [];
|
||||||
|
|
||||||
|
selection.forEach((selectedImageDTO) => {
|
||||||
|
if (unstarredImages.includes(selectedImageDTO.image_name)) {
|
||||||
|
updatedSelection.push({
|
||||||
|
...selectedImageDTO,
|
||||||
|
starred: false,
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
updatedSelection.push(selectedImageDTO);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
dispatch(selectionChanged(updatedSelection));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -15,12 +15,21 @@ import {
|
|||||||
setShouldUseSDXLRefiner,
|
setShouldUseSDXLRefiner,
|
||||||
} from 'features/sdxl/store/sdxlSlice';
|
} from 'features/sdxl/store/sdxlSlice';
|
||||||
import { forEach, some } from 'lodash-es';
|
import { forEach, some } from 'lodash-es';
|
||||||
import { modelsApi, vaeModelsAdapter } from 'services/api/endpoints/models';
|
import {
|
||||||
|
mainModelsAdapter,
|
||||||
|
modelsApi,
|
||||||
|
vaeModelsAdapter,
|
||||||
|
} from 'services/api/endpoints/models';
|
||||||
|
import { TypeGuardFor } from 'services/api/types';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
|
|
||||||
export const addModelsLoadedListener = () => {
|
export const addModelsLoadedListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
predicate: (state, action) =>
|
predicate: (
|
||||||
|
action
|
||||||
|
): action is TypeGuardFor<
|
||||||
|
typeof modelsApi.endpoints.getMainModels.matchFulfilled
|
||||||
|
> =>
|
||||||
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
|
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
|
||||||
!action.meta.arg.originalArgs.includes('sdxl-refiner'),
|
!action.meta.arg.originalArgs.includes('sdxl-refiner'),
|
||||||
effect: async (action, { getState, dispatch }) => {
|
effect: async (action, { getState, dispatch }) => {
|
||||||
@ -32,29 +41,28 @@ export const addModelsLoadedListener = () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const currentModel = getState().generation.model;
|
const currentModel = getState().generation.model;
|
||||||
|
const models = mainModelsAdapter.getSelectors().selectAll(action.payload);
|
||||||
|
|
||||||
const isCurrentModelAvailable = some(
|
if (models.length === 0) {
|
||||||
action.payload.entities,
|
|
||||||
(m) =>
|
|
||||||
m?.model_name === currentModel?.model_name &&
|
|
||||||
m?.base_model === currentModel?.base_model &&
|
|
||||||
m?.model_type === currentModel?.model_type
|
|
||||||
);
|
|
||||||
|
|
||||||
if (isCurrentModelAvailable) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const firstModelId = action.payload.ids[0];
|
|
||||||
const firstModel = action.payload.entities[firstModelId];
|
|
||||||
|
|
||||||
if (!firstModel) {
|
|
||||||
// No models loaded at all
|
// No models loaded at all
|
||||||
dispatch(modelChanged(null));
|
dispatch(modelChanged(null));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = zMainOrOnnxModel.safeParse(firstModel);
|
const isCurrentModelAvailable = currentModel
|
||||||
|
? models.some(
|
||||||
|
(m) =>
|
||||||
|
m.model_name === currentModel.model_name &&
|
||||||
|
m.base_model === currentModel.base_model &&
|
||||||
|
m.model_type === currentModel.model_type
|
||||||
|
)
|
||||||
|
: false;
|
||||||
|
|
||||||
|
if (isCurrentModelAvailable) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = zMainOrOnnxModel.safeParse(models[0]);
|
||||||
|
|
||||||
if (!result.success) {
|
if (!result.success) {
|
||||||
log.error(
|
log.error(
|
||||||
@ -68,7 +76,11 @@ export const addModelsLoadedListener = () => {
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
startAppListening({
|
startAppListening({
|
||||||
predicate: (state, action) =>
|
predicate: (
|
||||||
|
action
|
||||||
|
): action is TypeGuardFor<
|
||||||
|
typeof modelsApi.endpoints.getMainModels.matchFulfilled
|
||||||
|
> =>
|
||||||
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
|
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
|
||||||
action.meta.arg.originalArgs.includes('sdxl-refiner'),
|
action.meta.arg.originalArgs.includes('sdxl-refiner'),
|
||||||
effect: async (action, { getState, dispatch }) => {
|
effect: async (action, { getState, dispatch }) => {
|
||||||
@ -80,30 +92,29 @@ export const addModelsLoadedListener = () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const currentModel = getState().sdxl.refinerModel;
|
const currentModel = getState().sdxl.refinerModel;
|
||||||
|
const models = mainModelsAdapter.getSelectors().selectAll(action.payload);
|
||||||
|
|
||||||
const isCurrentModelAvailable = some(
|
if (models.length === 0) {
|
||||||
action.payload.entities,
|
|
||||||
(m) =>
|
|
||||||
m?.model_name === currentModel?.model_name &&
|
|
||||||
m?.base_model === currentModel?.base_model &&
|
|
||||||
m?.model_type === currentModel?.model_type
|
|
||||||
);
|
|
||||||
|
|
||||||
if (isCurrentModelAvailable) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const firstModelId = action.payload.ids[0];
|
|
||||||
const firstModel = action.payload.entities[firstModelId];
|
|
||||||
|
|
||||||
if (!firstModel) {
|
|
||||||
// No models loaded at all
|
// No models loaded at all
|
||||||
dispatch(refinerModelChanged(null));
|
dispatch(refinerModelChanged(null));
|
||||||
dispatch(setShouldUseSDXLRefiner(false));
|
dispatch(setShouldUseSDXLRefiner(false));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = zSDXLRefinerModel.safeParse(firstModel);
|
const isCurrentModelAvailable = currentModel
|
||||||
|
? models.some(
|
||||||
|
(m) =>
|
||||||
|
m.model_name === currentModel.model_name &&
|
||||||
|
m.base_model === currentModel.base_model &&
|
||||||
|
m.model_type === currentModel.model_type
|
||||||
|
)
|
||||||
|
: false;
|
||||||
|
|
||||||
|
if (isCurrentModelAvailable) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = zSDXLRefinerModel.safeParse(models[0]);
|
||||||
|
|
||||||
if (!result.success) {
|
if (!result.success) {
|
||||||
log.error(
|
log.error(
|
||||||
|
@ -13,7 +13,7 @@ export const addReceivedOpenAPISchemaListener = () => {
|
|||||||
const log = logger('system');
|
const log = logger('system');
|
||||||
const schemaJSON = action.payload;
|
const schemaJSON = action.payload;
|
||||||
|
|
||||||
log.debug({ schemaJSON }, 'Dereferenced OpenAPI schema');
|
log.debug({ schemaJSON }, 'Received OpenAPI schema');
|
||||||
|
|
||||||
const nodeTemplates = parseSchema(schemaJSON);
|
const nodeTemplates = parseSchema(schemaJSON);
|
||||||
|
|
||||||
@ -28,9 +28,12 @@ export const addReceivedOpenAPISchemaListener = () => {
|
|||||||
|
|
||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: receivedOpenAPISchema.rejected,
|
actionCreator: receivedOpenAPISchema.rejected,
|
||||||
effect: () => {
|
effect: (action) => {
|
||||||
const log = logger('system');
|
const log = logger('system');
|
||||||
log.error('Problem dereferencing OpenAPI Schema');
|
log.error(
|
||||||
|
{ error: parseify(action.error) },
|
||||||
|
'Problem retrieving OpenAPI Schema'
|
||||||
|
);
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -19,7 +19,7 @@ import {
|
|||||||
} from 'services/events/actions';
|
} from 'services/events/actions';
|
||||||
import { startAppListening } from '../..';
|
import { startAppListening } from '../..';
|
||||||
|
|
||||||
const nodeDenylist = ['dataURL_image'];
|
const nodeDenylist = ['load_image'];
|
||||||
|
|
||||||
export const addInvocationCompleteEventListener = () => {
|
export const addInvocationCompleteEventListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
|
@ -15,7 +15,7 @@ export const addUserInvokedNodesListener = () => {
|
|||||||
const log = logger('session');
|
const log = logger('session');
|
||||||
const state = getState();
|
const state = getState();
|
||||||
|
|
||||||
const graph = buildNodesGraph(state);
|
const graph = buildNodesGraph(state.nodes);
|
||||||
dispatch(nodesGraphBuilt(graph));
|
dispatch(nodesGraphBuilt(graph));
|
||||||
log.debug({ graph: parseify(graph) }, 'Nodes graph built');
|
log.debug({ graph: parseify(graph) }, 'Nodes graph built');
|
||||||
|
|
||||||
|
@ -1,86 +1,7 @@
|
|||||||
import {
|
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||||
// CONTROLNET_MODELS,
|
|
||||||
CONTROLNET_PROCESSORS,
|
|
||||||
} from 'features/controlNet/store/constants';
|
|
||||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||||
import { O } from 'ts-toolbelt';
|
import { O } from 'ts-toolbelt';
|
||||||
|
|
||||||
// These are old types from the model management UI
|
|
||||||
|
|
||||||
// export type ModelStatus = 'active' | 'cached' | 'not loaded';
|
|
||||||
|
|
||||||
// export type Model = {
|
|
||||||
// status: ModelStatus;
|
|
||||||
// description: string;
|
|
||||||
// weights: string;
|
|
||||||
// config?: string;
|
|
||||||
// vae?: string;
|
|
||||||
// width?: number;
|
|
||||||
// height?: number;
|
|
||||||
// default?: boolean;
|
|
||||||
// format?: string;
|
|
||||||
// };
|
|
||||||
|
|
||||||
// export type DiffusersModel = {
|
|
||||||
// status: ModelStatus;
|
|
||||||
// description: string;
|
|
||||||
// repo_id?: string;
|
|
||||||
// path?: string;
|
|
||||||
// vae?: {
|
|
||||||
// repo_id?: string;
|
|
||||||
// path?: string;
|
|
||||||
// };
|
|
||||||
// format?: string;
|
|
||||||
// default?: boolean;
|
|
||||||
// };
|
|
||||||
|
|
||||||
// export type ModelList = Record<string, Model & DiffusersModel>;
|
|
||||||
|
|
||||||
// export type FoundModel = {
|
|
||||||
// name: string;
|
|
||||||
// location: string;
|
|
||||||
// };
|
|
||||||
|
|
||||||
// export type InvokeModelConfigProps = {
|
|
||||||
// name: string | undefined;
|
|
||||||
// description: string | undefined;
|
|
||||||
// config: string | undefined;
|
|
||||||
// weights: string | undefined;
|
|
||||||
// vae: string | undefined;
|
|
||||||
// width: number | undefined;
|
|
||||||
// height: number | undefined;
|
|
||||||
// default: boolean | undefined;
|
|
||||||
// format: string | undefined;
|
|
||||||
// };
|
|
||||||
|
|
||||||
// export type InvokeDiffusersModelConfigProps = {
|
|
||||||
// name: string | undefined;
|
|
||||||
// description: string | undefined;
|
|
||||||
// repo_id: string | undefined;
|
|
||||||
// path: string | undefined;
|
|
||||||
// default: boolean | undefined;
|
|
||||||
// format: string | undefined;
|
|
||||||
// vae: {
|
|
||||||
// repo_id: string | undefined;
|
|
||||||
// path: string | undefined;
|
|
||||||
// };
|
|
||||||
// };
|
|
||||||
|
|
||||||
// export type InvokeModelConversionProps = {
|
|
||||||
// model_name: string;
|
|
||||||
// save_location: string;
|
|
||||||
// custom_location: string | null;
|
|
||||||
// };
|
|
||||||
|
|
||||||
// export type InvokeModelMergingProps = {
|
|
||||||
// models_to_merge: string[];
|
|
||||||
// alpha: number;
|
|
||||||
// interp: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference';
|
|
||||||
// force: boolean;
|
|
||||||
// merged_model_name: string;
|
|
||||||
// model_merge_save_path: string | null;
|
|
||||||
// };
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A disable-able application feature
|
* A disable-able application feature
|
||||||
*/
|
*/
|
||||||
|
126
invokeai/frontend/web/src/common/components/IAIContextMenu.tsx
Normal file
126
invokeai/frontend/web/src/common/components/IAIContextMenu.tsx
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
/**
|
||||||
|
* This is a copy-paste of https://github.com/lukasbach/chakra-ui-contextmenu with a small change.
|
||||||
|
*
|
||||||
|
* The reactflow background element somehow prevents the chakra `useOutsideClick()` hook from working.
|
||||||
|
* With a menu open, clicking on the reactflow background element doesn't close the menu.
|
||||||
|
*
|
||||||
|
* Reactflow does provide an `onPaneClick` to handle clicks on the background element, but it is not
|
||||||
|
* straightforward to programatically close the menu.
|
||||||
|
*
|
||||||
|
* As a (hopefully temporary) workaround, we will use a dirty hack:
|
||||||
|
* - create `globalContextMenuCloseTrigger: number` in `ui` slice
|
||||||
|
* - increment it in `onPaneClick`
|
||||||
|
* - `useEffect()` to close the menu when `globalContextMenuCloseTrigger` changes
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {
|
||||||
|
Menu,
|
||||||
|
MenuButton,
|
||||||
|
MenuButtonProps,
|
||||||
|
MenuProps,
|
||||||
|
Portal,
|
||||||
|
PortalProps,
|
||||||
|
useEventListener,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import * as React from 'react';
|
||||||
|
import {
|
||||||
|
MutableRefObject,
|
||||||
|
useCallback,
|
||||||
|
useEffect,
|
||||||
|
useRef,
|
||||||
|
useState,
|
||||||
|
} from 'react';
|
||||||
|
|
||||||
|
export interface IAIContextMenuProps<T extends HTMLElement> {
|
||||||
|
renderMenu: () => JSX.Element | null;
|
||||||
|
children: (ref: MutableRefObject<T | null>) => JSX.Element | null;
|
||||||
|
menuProps?: Omit<MenuProps, 'children'> & { children?: React.ReactNode };
|
||||||
|
portalProps?: Omit<PortalProps, 'children'> & { children?: React.ReactNode };
|
||||||
|
menuButtonProps?: MenuButtonProps;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function IAIContextMenu<T extends HTMLElement = HTMLElement>(
|
||||||
|
props: IAIContextMenuProps<T>
|
||||||
|
) {
|
||||||
|
const [isOpen, setIsOpen] = useState(false);
|
||||||
|
const [isRendered, setIsRendered] = useState(false);
|
||||||
|
const [isDeferredOpen, setIsDeferredOpen] = useState(false);
|
||||||
|
const [position, setPosition] = useState<[number, number]>([0, 0]);
|
||||||
|
const targetRef = useRef<T>(null);
|
||||||
|
|
||||||
|
const globalContextMenuCloseTrigger = useAppSelector(
|
||||||
|
(state) => state.ui.globalContextMenuCloseTrigger
|
||||||
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (isOpen) {
|
||||||
|
setTimeout(() => {
|
||||||
|
setIsRendered(true);
|
||||||
|
setTimeout(() => {
|
||||||
|
setIsDeferredOpen(true);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
setIsDeferredOpen(false);
|
||||||
|
const timeout = setTimeout(() => {
|
||||||
|
setIsRendered(isOpen);
|
||||||
|
}, 1000);
|
||||||
|
return () => clearTimeout(timeout);
|
||||||
|
}
|
||||||
|
}, [isOpen]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
setIsOpen(false);
|
||||||
|
setIsDeferredOpen(false);
|
||||||
|
setIsRendered(false);
|
||||||
|
}, [globalContextMenuCloseTrigger]);
|
||||||
|
|
||||||
|
useEventListener('contextmenu', (e) => {
|
||||||
|
if (
|
||||||
|
targetRef.current?.contains(e.target as HTMLElement) ||
|
||||||
|
e.target === targetRef.current
|
||||||
|
) {
|
||||||
|
e.preventDefault();
|
||||||
|
setIsOpen(true);
|
||||||
|
setPosition([e.pageX, e.pageY]);
|
||||||
|
} else {
|
||||||
|
setIsOpen(false);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
const onCloseHandler = useCallback(() => {
|
||||||
|
props.menuProps?.onClose?.();
|
||||||
|
setIsOpen(false);
|
||||||
|
}, [props.menuProps]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
{props.children(targetRef)}
|
||||||
|
{isRendered && (
|
||||||
|
<Portal {...props.portalProps}>
|
||||||
|
<Menu
|
||||||
|
isOpen={isDeferredOpen}
|
||||||
|
gutter={0}
|
||||||
|
{...props.menuProps}
|
||||||
|
onClose={onCloseHandler}
|
||||||
|
>
|
||||||
|
<MenuButton
|
||||||
|
aria-hidden={true}
|
||||||
|
w={1}
|
||||||
|
h={1}
|
||||||
|
style={{
|
||||||
|
position: 'absolute',
|
||||||
|
left: position[0],
|
||||||
|
top: position[1],
|
||||||
|
cursor: 'default',
|
||||||
|
}}
|
||||||
|
{...props.menuButtonProps}
|
||||||
|
/>
|
||||||
|
{props.renderMenu()}
|
||||||
|
</Menu>
|
||||||
|
</Portal>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
@ -1,16 +1,11 @@
|
|||||||
import {
|
import {
|
||||||
ChakraProps,
|
ChakraProps,
|
||||||
Flex,
|
Flex,
|
||||||
|
FlexProps,
|
||||||
Icon,
|
Icon,
|
||||||
Image,
|
Image,
|
||||||
useColorMode,
|
useColorMode,
|
||||||
useColorModeValue,
|
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import {
|
|
||||||
TypesafeDraggableData,
|
|
||||||
TypesafeDroppableData,
|
|
||||||
} from 'app/components/ImageDnd/typesafeDnd';
|
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
|
||||||
import {
|
import {
|
||||||
IAILoadingImageFallback,
|
IAILoadingImageFallback,
|
||||||
IAINoContentFallback,
|
IAINoContentFallback,
|
||||||
@ -21,27 +16,39 @@ import ImageContextMenu from 'features/gallery/components/ImageContextMenu/Image
|
|||||||
import {
|
import {
|
||||||
MouseEvent,
|
MouseEvent,
|
||||||
ReactElement,
|
ReactElement,
|
||||||
|
ReactNode,
|
||||||
SyntheticEvent,
|
SyntheticEvent,
|
||||||
memo,
|
memo,
|
||||||
useCallback,
|
useCallback,
|
||||||
useState,
|
useState,
|
||||||
} from 'react';
|
} from 'react';
|
||||||
import { FaImage, FaUndo, FaUpload } from 'react-icons/fa';
|
import { FaImage, FaUpload } from 'react-icons/fa';
|
||||||
import { ImageDTO, PostUploadAction } from 'services/api/types';
|
import { ImageDTO, PostUploadAction } from 'services/api/types';
|
||||||
import { mode } from 'theme/util/mode';
|
import { mode } from 'theme/util/mode';
|
||||||
import IAIDraggable from './IAIDraggable';
|
import IAIDraggable from './IAIDraggable';
|
||||||
import IAIDroppable from './IAIDroppable';
|
import IAIDroppable from './IAIDroppable';
|
||||||
import SelectionOverlay from './SelectionOverlay';
|
import SelectionOverlay from './SelectionOverlay';
|
||||||
|
import {
|
||||||
|
TypesafeDraggableData,
|
||||||
|
TypesafeDroppableData,
|
||||||
|
} from 'features/dnd/types';
|
||||||
|
|
||||||
type IAIDndImageProps = {
|
const defaultUploadElement = (
|
||||||
|
<Icon
|
||||||
|
as={FaUpload}
|
||||||
|
sx={{
|
||||||
|
boxSize: 16,
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
|
||||||
|
const defaultNoContentFallback = <IAINoContentFallback icon={FaImage} />;
|
||||||
|
|
||||||
|
type IAIDndImageProps = FlexProps & {
|
||||||
imageDTO: ImageDTO | undefined;
|
imageDTO: ImageDTO | undefined;
|
||||||
onError?: (event: SyntheticEvent<HTMLImageElement>) => void;
|
onError?: (event: SyntheticEvent<HTMLImageElement>) => void;
|
||||||
onLoad?: (event: SyntheticEvent<HTMLImageElement>) => void;
|
onLoad?: (event: SyntheticEvent<HTMLImageElement>) => void;
|
||||||
onClick?: (event: MouseEvent<HTMLDivElement>) => void;
|
onClick?: (event: MouseEvent<HTMLDivElement>) => void;
|
||||||
onClickReset?: (event: MouseEvent<HTMLButtonElement>) => void;
|
|
||||||
withResetIcon?: boolean;
|
|
||||||
resetIcon?: ReactElement;
|
|
||||||
resetTooltip?: string;
|
|
||||||
withMetadataOverlay?: boolean;
|
withMetadataOverlay?: boolean;
|
||||||
isDragDisabled?: boolean;
|
isDragDisabled?: boolean;
|
||||||
isDropDisabled?: boolean;
|
isDropDisabled?: boolean;
|
||||||
@ -52,21 +59,21 @@ type IAIDndImageProps = {
|
|||||||
fitContainer?: boolean;
|
fitContainer?: boolean;
|
||||||
droppableData?: TypesafeDroppableData;
|
droppableData?: TypesafeDroppableData;
|
||||||
draggableData?: TypesafeDraggableData;
|
draggableData?: TypesafeDraggableData;
|
||||||
dropLabel?: string;
|
dropLabel?: ReactNode;
|
||||||
isSelected?: boolean;
|
isSelected?: boolean;
|
||||||
thumbnail?: boolean;
|
thumbnail?: boolean;
|
||||||
noContentFallback?: ReactElement;
|
noContentFallback?: ReactElement;
|
||||||
useThumbailFallback?: boolean;
|
useThumbailFallback?: boolean;
|
||||||
withHoverOverlay?: boolean;
|
withHoverOverlay?: boolean;
|
||||||
|
children?: JSX.Element;
|
||||||
|
uploadElement?: ReactNode;
|
||||||
};
|
};
|
||||||
|
|
||||||
const IAIDndImage = (props: IAIDndImageProps) => {
|
const IAIDndImage = (props: IAIDndImageProps) => {
|
||||||
const {
|
const {
|
||||||
imageDTO,
|
imageDTO,
|
||||||
onClickReset,
|
|
||||||
onError,
|
onError,
|
||||||
onClick,
|
onClick,
|
||||||
withResetIcon = false,
|
|
||||||
withMetadataOverlay = false,
|
withMetadataOverlay = false,
|
||||||
isDropDisabled = false,
|
isDropDisabled = false,
|
||||||
isDragDisabled = false,
|
isDragDisabled = false,
|
||||||
@ -80,32 +87,37 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
|||||||
dropLabel,
|
dropLabel,
|
||||||
isSelected = false,
|
isSelected = false,
|
||||||
thumbnail = false,
|
thumbnail = false,
|
||||||
resetTooltip = 'Reset',
|
noContentFallback = defaultNoContentFallback,
|
||||||
resetIcon = <FaUndo />,
|
uploadElement = defaultUploadElement,
|
||||||
noContentFallback = <IAINoContentFallback icon={FaImage} />,
|
|
||||||
useThumbailFallback,
|
useThumbailFallback,
|
||||||
withHoverOverlay = false,
|
withHoverOverlay = false,
|
||||||
|
children,
|
||||||
|
onMouseOver,
|
||||||
|
onMouseOut,
|
||||||
} = props;
|
} = props;
|
||||||
|
|
||||||
const { colorMode } = useColorMode();
|
const { colorMode } = useColorMode();
|
||||||
const [isHovered, setIsHovered] = useState(false);
|
const [isHovered, setIsHovered] = useState(false);
|
||||||
const handleMouseOver = useCallback(() => {
|
const handleMouseOver = useCallback(
|
||||||
setIsHovered(true);
|
(e: MouseEvent<HTMLDivElement>) => {
|
||||||
}, []);
|
if (onMouseOver) onMouseOver(e);
|
||||||
const handleMouseOut = useCallback(() => {
|
setIsHovered(true);
|
||||||
setIsHovered(false);
|
},
|
||||||
}, []);
|
[onMouseOver]
|
||||||
|
);
|
||||||
|
const handleMouseOut = useCallback(
|
||||||
|
(e: MouseEvent<HTMLDivElement>) => {
|
||||||
|
if (onMouseOut) onMouseOut(e);
|
||||||
|
setIsHovered(false);
|
||||||
|
},
|
||||||
|
[onMouseOut]
|
||||||
|
);
|
||||||
|
|
||||||
const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({
|
const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({
|
||||||
postUploadAction,
|
postUploadAction,
|
||||||
isDisabled: isUploadDisabled,
|
isDisabled: isUploadDisabled,
|
||||||
});
|
});
|
||||||
|
|
||||||
const resetIconShadow = useColorModeValue(
|
|
||||||
`drop-shadow(0px 0px 0.1rem var(--invokeai-colors-base-600))`,
|
|
||||||
`drop-shadow(0px 0px 0.1rem var(--invokeai-colors-base-800))`
|
|
||||||
);
|
|
||||||
|
|
||||||
const uploadButtonStyles = isUploadDisabled
|
const uploadButtonStyles = isUploadDisabled
|
||||||
? {}
|
? {}
|
||||||
: {
|
: {
|
||||||
@ -157,11 +169,10 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
|||||||
<IAILoadingImageFallback image={imageDTO} />
|
<IAILoadingImageFallback image={imageDTO} />
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
width={imageDTO.width}
|
|
||||||
height={imageDTO.height}
|
|
||||||
onError={onError}
|
onError={onError}
|
||||||
draggable={false}
|
draggable={false}
|
||||||
sx={{
|
sx={{
|
||||||
|
w: imageDTO.width,
|
||||||
objectFit: 'contain',
|
objectFit: 'contain',
|
||||||
maxW: 'full',
|
maxW: 'full',
|
||||||
maxH: 'full',
|
maxH: 'full',
|
||||||
@ -196,12 +207,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
|||||||
{...getUploadButtonProps()}
|
{...getUploadButtonProps()}
|
||||||
>
|
>
|
||||||
<input {...getUploadInputProps()} />
|
<input {...getUploadInputProps()} />
|
||||||
<Icon
|
{uploadElement}
|
||||||
as={FaUpload}
|
|
||||||
sx={{
|
|
||||||
boxSize: 16,
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
</Flex>
|
</Flex>
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
@ -213,6 +219,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
|||||||
onClick={onClick}
|
onClick={onClick}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
|
{children}
|
||||||
{!isDropDisabled && (
|
{!isDropDisabled && (
|
||||||
<IAIDroppable
|
<IAIDroppable
|
||||||
data={droppableData}
|
data={droppableData}
|
||||||
@ -220,30 +227,6 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
|||||||
dropLabel={dropLabel}
|
dropLabel={dropLabel}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{onClickReset && withResetIcon && imageDTO && (
|
|
||||||
<IAIIconButton
|
|
||||||
onClick={onClickReset}
|
|
||||||
aria-label={resetTooltip}
|
|
||||||
tooltip={resetTooltip}
|
|
||||||
icon={resetIcon}
|
|
||||||
size="sm"
|
|
||||||
variant="link"
|
|
||||||
sx={{
|
|
||||||
position: 'absolute',
|
|
||||||
top: 1,
|
|
||||||
insetInlineEnd: 1,
|
|
||||||
p: 0,
|
|
||||||
minW: 0,
|
|
||||||
svg: {
|
|
||||||
transitionProperty: 'common',
|
|
||||||
transitionDuration: 'normal',
|
|
||||||
fill: 'base.100',
|
|
||||||
_hover: { fill: 'base.50' },
|
|
||||||
filter: resetIconShadow,
|
|
||||||
},
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</Flex>
|
</Flex>
|
||||||
)}
|
)}
|
||||||
</ImageContextMenu>
|
</ImageContextMenu>
|
||||||
|
@ -0,0 +1,46 @@
|
|||||||
|
import { SystemStyleObject, useColorModeValue } from '@chakra-ui/react';
|
||||||
|
import { MouseEvent, ReactElement, memo } from 'react';
|
||||||
|
import IAIIconButton from './IAIIconButton';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
onClick: (event: MouseEvent<HTMLButtonElement>) => void;
|
||||||
|
tooltip: string;
|
||||||
|
icon?: ReactElement;
|
||||||
|
styleOverrides?: SystemStyleObject;
|
||||||
|
};
|
||||||
|
|
||||||
|
const IAIDndImageIcon = (props: Props) => {
|
||||||
|
const { onClick, tooltip, icon, styleOverrides } = props;
|
||||||
|
|
||||||
|
const resetIconShadow = useColorModeValue(
|
||||||
|
`drop-shadow(0px 0px 0.1rem var(--invokeai-colors-base-600))`,
|
||||||
|
`drop-shadow(0px 0px 0.1rem var(--invokeai-colors-base-800))`
|
||||||
|
);
|
||||||
|
return (
|
||||||
|
<IAIIconButton
|
||||||
|
onClick={onClick}
|
||||||
|
aria-label={tooltip}
|
||||||
|
tooltip={tooltip}
|
||||||
|
icon={icon}
|
||||||
|
size="sm"
|
||||||
|
variant="link"
|
||||||
|
sx={{
|
||||||
|
position: 'absolute',
|
||||||
|
top: 1,
|
||||||
|
insetInlineEnd: 1,
|
||||||
|
p: 0,
|
||||||
|
minW: 0,
|
||||||
|
svg: {
|
||||||
|
transitionProperty: 'common',
|
||||||
|
transitionDuration: 'normal',
|
||||||
|
fill: 'base.100',
|
||||||
|
_hover: { fill: 'base.50' },
|
||||||
|
filter: resetIconShadow,
|
||||||
|
},
|
||||||
|
...styleOverrides,
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(IAIDndImageIcon);
|
@ -1,22 +1,19 @@
|
|||||||
import { Box } from '@chakra-ui/react';
|
import { Box, BoxProps } from '@chakra-ui/react';
|
||||||
import {
|
import { useDraggableTypesafe } from 'features/dnd/hooks/typesafeHooks';
|
||||||
TypesafeDraggableData,
|
import { TypesafeDraggableData } from 'features/dnd/types';
|
||||||
useDraggable,
|
import { memo, useRef } from 'react';
|
||||||
} from 'app/components/ImageDnd/typesafeDnd';
|
|
||||||
import { MouseEvent, memo, useRef } from 'react';
|
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
|
||||||
type IAIDraggableProps = {
|
type IAIDraggableProps = BoxProps & {
|
||||||
disabled?: boolean;
|
disabled?: boolean;
|
||||||
data?: TypesafeDraggableData;
|
data?: TypesafeDraggableData;
|
||||||
onClick?: (event: MouseEvent<HTMLDivElement>) => void;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const IAIDraggable = (props: IAIDraggableProps) => {
|
const IAIDraggable = (props: IAIDraggableProps) => {
|
||||||
const { data, disabled, onClick } = props;
|
const { data, disabled, ...rest } = props;
|
||||||
const dndId = useRef(uuidv4());
|
const dndId = useRef(uuidv4());
|
||||||
|
|
||||||
const { attributes, listeners, setNodeRef } = useDraggable({
|
const { attributes, listeners, setNodeRef } = useDraggableTypesafe({
|
||||||
id: dndId.current,
|
id: dndId.current,
|
||||||
disabled,
|
disabled,
|
||||||
data,
|
data,
|
||||||
@ -24,7 +21,6 @@ const IAIDraggable = (props: IAIDraggableProps) => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<Box
|
<Box
|
||||||
onClick={onClick}
|
|
||||||
ref={setNodeRef}
|
ref={setNodeRef}
|
||||||
position="absolute"
|
position="absolute"
|
||||||
w="full"
|
w="full"
|
||||||
@ -33,6 +29,7 @@ const IAIDraggable = (props: IAIDraggableProps) => {
|
|||||||
insetInlineStart={0}
|
insetInlineStart={0}
|
||||||
{...attributes}
|
{...attributes}
|
||||||
{...listeners}
|
{...listeners}
|
||||||
|
{...rest}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -1,9 +1,7 @@
|
|||||||
import { Box } from '@chakra-ui/react';
|
import { Box } from '@chakra-ui/react';
|
||||||
import {
|
import { useDroppableTypesafe } from 'features/dnd/hooks/typesafeHooks';
|
||||||
TypesafeDroppableData,
|
import { TypesafeDroppableData } from 'features/dnd/types';
|
||||||
isValidDrop,
|
import { isValidDrop } from 'features/dnd/util/isValidDrop';
|
||||||
useDroppable,
|
|
||||||
} from 'app/components/ImageDnd/typesafeDnd';
|
|
||||||
import { AnimatePresence } from 'framer-motion';
|
import { AnimatePresence } from 'framer-motion';
|
||||||
import { ReactNode, memo, useRef } from 'react';
|
import { ReactNode, memo, useRef } from 'react';
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
@ -19,7 +17,7 @@ const IAIDroppable = (props: IAIDroppableProps) => {
|
|||||||
const { dropLabel, data, disabled } = props;
|
const { dropLabel, data, disabled } = props;
|
||||||
const dndId = useRef(uuidv4());
|
const dndId = useRef(uuidv4());
|
||||||
|
|
||||||
const { isOver, setNodeRef, active } = useDroppable({
|
const { isOver, setNodeRef, active } = useDroppableTypesafe({
|
||||||
id: dndId.current,
|
id: dndId.current,
|
||||||
disabled,
|
disabled,
|
||||||
data,
|
data,
|
||||||
|
@ -49,7 +49,7 @@ export const IAILoadingImageFallback = (props: Props) => {
|
|||||||
|
|
||||||
type IAINoImageFallbackProps = {
|
type IAINoImageFallbackProps = {
|
||||||
label?: string;
|
label?: string;
|
||||||
icon?: As;
|
icon?: As | null;
|
||||||
boxSize?: StyleProps['boxSize'];
|
boxSize?: StyleProps['boxSize'];
|
||||||
sx?: ChakraProps['sx'];
|
sx?: ChakraProps['sx'];
|
||||||
};
|
};
|
||||||
@ -76,7 +76,7 @@ export const IAINoContentFallback = (props: IAINoImageFallbackProps) => {
|
|||||||
...props.sx,
|
...props.sx,
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Icon as={icon} boxSize={boxSize} opacity={0.7} />
|
{icon && <Icon as={icon} boxSize={boxSize} opacity={0.7} />}
|
||||||
{props.label && <Text textAlign="center">{props.label}</Text>}
|
{props.label && <Text textAlign="center">{props.label}</Text>}
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
|
@ -1,10 +1,13 @@
|
|||||||
import {
|
import {
|
||||||
|
Flex,
|
||||||
FormControl,
|
FormControl,
|
||||||
FormControlProps,
|
FormControlProps,
|
||||||
|
FormHelperText,
|
||||||
FormLabel,
|
FormLabel,
|
||||||
FormLabelProps,
|
FormLabelProps,
|
||||||
Switch,
|
Switch,
|
||||||
SwitchProps,
|
SwitchProps,
|
||||||
|
Text,
|
||||||
Tooltip,
|
Tooltip,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
@ -15,6 +18,7 @@ export interface IAISwitchProps extends SwitchProps {
|
|||||||
formControlProps?: FormControlProps;
|
formControlProps?: FormControlProps;
|
||||||
formLabelProps?: FormLabelProps;
|
formLabelProps?: FormLabelProps;
|
||||||
tooltip?: string;
|
tooltip?: string;
|
||||||
|
helperText?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -28,6 +32,7 @@ const IAISwitch = (props: IAISwitchProps) => {
|
|||||||
formControlProps,
|
formControlProps,
|
||||||
formLabelProps,
|
formLabelProps,
|
||||||
tooltip,
|
tooltip,
|
||||||
|
helperText,
|
||||||
...rest
|
...rest
|
||||||
} = props;
|
} = props;
|
||||||
return (
|
return (
|
||||||
@ -35,25 +40,33 @@ const IAISwitch = (props: IAISwitchProps) => {
|
|||||||
<FormControl
|
<FormControl
|
||||||
isDisabled={isDisabled}
|
isDisabled={isDisabled}
|
||||||
width={width}
|
width={width}
|
||||||
display="flex"
|
|
||||||
alignItems="center"
|
alignItems="center"
|
||||||
{...formControlProps}
|
{...formControlProps}
|
||||||
>
|
>
|
||||||
{label && (
|
<Flex sx={{ flexDir: 'column', w: 'full' }}>
|
||||||
<FormLabel
|
<Flex sx={{ alignItems: 'center', w: 'full' }}>
|
||||||
my={1}
|
{label && (
|
||||||
flexGrow={1}
|
<FormLabel
|
||||||
sx={{
|
my={1}
|
||||||
cursor: isDisabled ? 'not-allowed' : 'pointer',
|
flexGrow={1}
|
||||||
...formLabelProps?.sx,
|
sx={{
|
||||||
pe: 4,
|
cursor: isDisabled ? 'not-allowed' : 'pointer',
|
||||||
}}
|
...formLabelProps?.sx,
|
||||||
{...formLabelProps}
|
pe: 4,
|
||||||
>
|
}}
|
||||||
{label}
|
{...formLabelProps}
|
||||||
</FormLabel>
|
>
|
||||||
)}
|
{label}
|
||||||
<Switch {...rest} />
|
</FormLabel>
|
||||||
|
)}
|
||||||
|
<Switch {...rest} />
|
||||||
|
</Flex>
|
||||||
|
{helperText && (
|
||||||
|
<FormHelperText>
|
||||||
|
<Text variant="subtext">{helperText}</Text>
|
||||||
|
</FormHelperText>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
);
|
);
|
||||||
|
@ -40,6 +40,44 @@ export const useChakraThemeTokens = () => {
|
|||||||
accent850,
|
accent850,
|
||||||
accent900,
|
accent900,
|
||||||
accent950,
|
accent950,
|
||||||
|
baseAlpha50,
|
||||||
|
baseAlpha100,
|
||||||
|
baseAlpha150,
|
||||||
|
baseAlpha200,
|
||||||
|
baseAlpha250,
|
||||||
|
baseAlpha300,
|
||||||
|
baseAlpha350,
|
||||||
|
baseAlpha400,
|
||||||
|
baseAlpha450,
|
||||||
|
baseAlpha500,
|
||||||
|
baseAlpha550,
|
||||||
|
baseAlpha600,
|
||||||
|
baseAlpha650,
|
||||||
|
baseAlpha700,
|
||||||
|
baseAlpha750,
|
||||||
|
baseAlpha800,
|
||||||
|
baseAlpha850,
|
||||||
|
baseAlpha900,
|
||||||
|
baseAlpha950,
|
||||||
|
accentAlpha50,
|
||||||
|
accentAlpha100,
|
||||||
|
accentAlpha150,
|
||||||
|
accentAlpha200,
|
||||||
|
accentAlpha250,
|
||||||
|
accentAlpha300,
|
||||||
|
accentAlpha350,
|
||||||
|
accentAlpha400,
|
||||||
|
accentAlpha450,
|
||||||
|
accentAlpha500,
|
||||||
|
accentAlpha550,
|
||||||
|
accentAlpha600,
|
||||||
|
accentAlpha650,
|
||||||
|
accentAlpha700,
|
||||||
|
accentAlpha750,
|
||||||
|
accentAlpha800,
|
||||||
|
accentAlpha850,
|
||||||
|
accentAlpha900,
|
||||||
|
accentAlpha950,
|
||||||
] = useToken('colors', [
|
] = useToken('colors', [
|
||||||
'base.50',
|
'base.50',
|
||||||
'base.100',
|
'base.100',
|
||||||
@ -79,6 +117,44 @@ export const useChakraThemeTokens = () => {
|
|||||||
'accent.850',
|
'accent.850',
|
||||||
'accent.900',
|
'accent.900',
|
||||||
'accent.950',
|
'accent.950',
|
||||||
|
'baseAlpha.50',
|
||||||
|
'baseAlpha.100',
|
||||||
|
'baseAlpha.150',
|
||||||
|
'baseAlpha.200',
|
||||||
|
'baseAlpha.250',
|
||||||
|
'baseAlpha.300',
|
||||||
|
'baseAlpha.350',
|
||||||
|
'baseAlpha.400',
|
||||||
|
'baseAlpha.450',
|
||||||
|
'baseAlpha.500',
|
||||||
|
'baseAlpha.550',
|
||||||
|
'baseAlpha.600',
|
||||||
|
'baseAlpha.650',
|
||||||
|
'baseAlpha.700',
|
||||||
|
'baseAlpha.750',
|
||||||
|
'baseAlpha.800',
|
||||||
|
'baseAlpha.850',
|
||||||
|
'baseAlpha.900',
|
||||||
|
'baseAlpha.950',
|
||||||
|
'accentAlpha.50',
|
||||||
|
'accentAlpha.100',
|
||||||
|
'accentAlpha.150',
|
||||||
|
'accentAlpha.200',
|
||||||
|
'accentAlpha.250',
|
||||||
|
'accentAlpha.300',
|
||||||
|
'accentAlpha.350',
|
||||||
|
'accentAlpha.400',
|
||||||
|
'accentAlpha.450',
|
||||||
|
'accentAlpha.500',
|
||||||
|
'accentAlpha.550',
|
||||||
|
'accentAlpha.600',
|
||||||
|
'accentAlpha.650',
|
||||||
|
'accentAlpha.700',
|
||||||
|
'accentAlpha.750',
|
||||||
|
'accentAlpha.800',
|
||||||
|
'accentAlpha.850',
|
||||||
|
'accentAlpha.900',
|
||||||
|
'accentAlpha.950',
|
||||||
]);
|
]);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@ -120,5 +196,43 @@ export const useChakraThemeTokens = () => {
|
|||||||
accent850,
|
accent850,
|
||||||
accent900,
|
accent900,
|
||||||
accent950,
|
accent950,
|
||||||
|
baseAlpha50,
|
||||||
|
baseAlpha100,
|
||||||
|
baseAlpha150,
|
||||||
|
baseAlpha200,
|
||||||
|
baseAlpha250,
|
||||||
|
baseAlpha300,
|
||||||
|
baseAlpha350,
|
||||||
|
baseAlpha400,
|
||||||
|
baseAlpha450,
|
||||||
|
baseAlpha500,
|
||||||
|
baseAlpha550,
|
||||||
|
baseAlpha600,
|
||||||
|
baseAlpha650,
|
||||||
|
baseAlpha700,
|
||||||
|
baseAlpha750,
|
||||||
|
baseAlpha800,
|
||||||
|
baseAlpha850,
|
||||||
|
baseAlpha900,
|
||||||
|
baseAlpha950,
|
||||||
|
accentAlpha50,
|
||||||
|
accentAlpha100,
|
||||||
|
accentAlpha150,
|
||||||
|
accentAlpha200,
|
||||||
|
accentAlpha250,
|
||||||
|
accentAlpha300,
|
||||||
|
accentAlpha350,
|
||||||
|
accentAlpha400,
|
||||||
|
accentAlpha450,
|
||||||
|
accentAlpha500,
|
||||||
|
accentAlpha550,
|
||||||
|
accentAlpha600,
|
||||||
|
accentAlpha650,
|
||||||
|
accentAlpha700,
|
||||||
|
accentAlpha750,
|
||||||
|
accentAlpha800,
|
||||||
|
accentAlpha850,
|
||||||
|
accentAlpha900,
|
||||||
|
accentAlpha950,
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
@ -1,4 +1,10 @@
|
|||||||
/**
|
/**
|
||||||
* Serialize an object to JSON and back to a new object
|
* Serialize an object to JSON and back to a new object
|
||||||
*/
|
*/
|
||||||
export const parseify = (obj: unknown) => JSON.parse(JSON.stringify(obj));
|
export const parseify = (obj: unknown) => {
|
||||||
|
try {
|
||||||
|
return JSON.parse(JSON.stringify(obj));
|
||||||
|
} catch {
|
||||||
|
return 'Error parsing object';
|
||||||
|
}
|
||||||
|
};
|
||||||
|
@ -2,10 +2,11 @@ import { ButtonGroup, Flex } from '@chakra-ui/react';
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
|
||||||
import IAIColorPicker from 'common/components/IAIColorPicker';
|
import IAIColorPicker from 'common/components/IAIColorPicker';
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
import IAIPopover from 'common/components/IAIPopover';
|
import IAIPopover from 'common/components/IAIPopover';
|
||||||
|
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
||||||
|
import { canvasMaskSavedToGallery } from 'features/canvas/store/actions';
|
||||||
import {
|
import {
|
||||||
canvasSelector,
|
canvasSelector,
|
||||||
isStagingSelector,
|
isStagingSelector,
|
||||||
@ -22,7 +23,7 @@ import { isEqual } from 'lodash-es';
|
|||||||
|
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { FaMask, FaTrash } from 'react-icons/fa';
|
import { FaMask, FaSave, FaTrash } from 'react-icons/fa';
|
||||||
|
|
||||||
export const selector = createSelector(
|
export const selector = createSelector(
|
||||||
[canvasSelector, isStagingSelector],
|
[canvasSelector, isStagingSelector],
|
||||||
@ -102,6 +103,10 @@ const IAICanvasMaskOptions = () => {
|
|||||||
const handleToggleEnableMask = () =>
|
const handleToggleEnableMask = () =>
|
||||||
dispatch(setIsMaskEnabled(!isMaskEnabled));
|
dispatch(setIsMaskEnabled(!isMaskEnabled));
|
||||||
|
|
||||||
|
const handleSaveMask = async () => {
|
||||||
|
dispatch(canvasMaskSavedToGallery());
|
||||||
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAIPopover
|
<IAIPopover
|
||||||
triggerComponent={
|
triggerComponent={
|
||||||
@ -134,6 +139,9 @@ const IAICanvasMaskOptions = () => {
|
|||||||
pickerColor={maskColor}
|
pickerColor={maskColor}
|
||||||
onChange={(newColor) => dispatch(setMaskColor(newColor))}
|
onChange={(newColor) => dispatch(setMaskColor(newColor))}
|
||||||
/>
|
/>
|
||||||
|
<IAIButton size="sm" leftIcon={<FaSave />} onClick={handleSaveMask}>
|
||||||
|
Save Mask
|
||||||
|
</IAIButton>
|
||||||
<IAIButton size="sm" leftIcon={<FaTrash />} onClick={handleClearMask}>
|
<IAIButton size="sm" leftIcon={<FaTrash />} onClick={handleClearMask}>
|
||||||
{t('unifiedCanvas.clearMask')} (Shift+C)
|
{t('unifiedCanvas.clearMask')} (Shift+C)
|
||||||
</IAIButton>
|
</IAIButton>
|
||||||
|
@ -3,6 +3,10 @@ import { ImageDTO } from 'services/api/types';
|
|||||||
|
|
||||||
export const canvasSavedToGallery = createAction('canvas/canvasSavedToGallery');
|
export const canvasSavedToGallery = createAction('canvas/canvasSavedToGallery');
|
||||||
|
|
||||||
|
export const canvasMaskSavedToGallery = createAction(
|
||||||
|
'canvas/canvasMaskSavedToGallery'
|
||||||
|
);
|
||||||
|
|
||||||
export const canvasCopiedToClipboard = createAction(
|
export const canvasCopiedToClipboard = createAction(
|
||||||
'canvas/canvasCopiedToClipboard'
|
'canvas/canvasCopiedToClipboard'
|
||||||
);
|
);
|
||||||
|
@ -4,14 +4,16 @@ import { skipToken } from '@reduxjs/toolkit/dist/query';
|
|||||||
import {
|
import {
|
||||||
TypesafeDraggableData,
|
TypesafeDraggableData,
|
||||||
TypesafeDroppableData,
|
TypesafeDroppableData,
|
||||||
} from 'app/components/ImageDnd/typesafeDnd';
|
} from 'features/dnd/types';
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAIDndImage from 'common/components/IAIDndImage';
|
import IAIDndImage from 'common/components/IAIDndImage';
|
||||||
import { memo, useCallback, useMemo, useState } from 'react';
|
import { memo, useCallback, useMemo, useState } from 'react';
|
||||||
|
import { FaUndo } from 'react-icons/fa';
|
||||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||||
import { PostUploadAction } from 'services/api/types';
|
import { PostUploadAction } from 'services/api/types';
|
||||||
|
import IAIDndImageIcon from '../../../common/components/IAIDndImageIcon';
|
||||||
import {
|
import {
|
||||||
ControlNetConfig,
|
ControlNetConfig,
|
||||||
controlNetImageChanged,
|
controlNetImageChanged,
|
||||||
@ -119,11 +121,15 @@ const ControlNetImagePreview = (props: Props) => {
|
|||||||
droppableData={droppableData}
|
droppableData={droppableData}
|
||||||
imageDTO={controlImage}
|
imageDTO={controlImage}
|
||||||
isDropDisabled={shouldShowProcessedImage || !isEnabled}
|
isDropDisabled={shouldShowProcessedImage || !isEnabled}
|
||||||
onClickReset={handleResetControlImage}
|
|
||||||
postUploadAction={postUploadAction}
|
postUploadAction={postUploadAction}
|
||||||
resetTooltip="Reset Control Image"
|
>
|
||||||
withResetIcon={Boolean(controlImage)}
|
<IAIDndImageIcon
|
||||||
/>
|
onClick={handleResetControlImage}
|
||||||
|
icon={controlImage ? <FaUndo /> : undefined}
|
||||||
|
tooltip="Reset Control Image"
|
||||||
|
/>
|
||||||
|
</IAIDndImage>
|
||||||
|
|
||||||
<Box
|
<Box
|
||||||
sx={{
|
sx={{
|
||||||
position: 'absolute',
|
position: 'absolute',
|
||||||
@ -143,10 +149,13 @@ const ControlNetImagePreview = (props: Props) => {
|
|||||||
imageDTO={processedControlImage}
|
imageDTO={processedControlImage}
|
||||||
isUploadDisabled={true}
|
isUploadDisabled={true}
|
||||||
isDropDisabled={!isEnabled}
|
isDropDisabled={!isEnabled}
|
||||||
onClickReset={handleResetControlImage}
|
>
|
||||||
resetTooltip="Reset Control Image"
|
<IAIDndImageIcon
|
||||||
withResetIcon={Boolean(controlImage)}
|
onClick={handleResetControlImage}
|
||||||
/>
|
icon={controlImage ? <FaUndo /> : undefined}
|
||||||
|
tooltip="Reset Control Image"
|
||||||
|
/>
|
||||||
|
</IAIDndImage>
|
||||||
</Box>
|
</Box>
|
||||||
{pendingControlImages.includes(controlNetId) && (
|
{pendingControlImages.includes(controlNetId) && (
|
||||||
<Flex
|
<Flex
|
||||||
|
@ -138,7 +138,7 @@ export type RequiredZoeDepthImageProcessorInvocation = O.Required<
|
|||||||
/**
|
/**
|
||||||
* Any ControlNet Processor node, with its parameters flagged as required
|
* Any ControlNet Processor node, with its parameters flagged as required
|
||||||
*/
|
*/
|
||||||
export type RequiredControlNetProcessorNode =
|
export type RequiredControlNetProcessorNode = O.Required<
|
||||||
| RequiredCannyImageProcessorInvocation
|
| RequiredCannyImageProcessorInvocation
|
||||||
| RequiredContentShuffleImageProcessorInvocation
|
| RequiredContentShuffleImageProcessorInvocation
|
||||||
| RequiredHedImageProcessorInvocation
|
| RequiredHedImageProcessorInvocation
|
||||||
@ -150,7 +150,9 @@ export type RequiredControlNetProcessorNode =
|
|||||||
| RequiredNormalbaeImageProcessorInvocation
|
| RequiredNormalbaeImageProcessorInvocation
|
||||||
| RequiredOpenposeImageProcessorInvocation
|
| RequiredOpenposeImageProcessorInvocation
|
||||||
| RequiredPidiImageProcessorInvocation
|
| RequiredPidiImageProcessorInvocation
|
||||||
| RequiredZoeDepthImageProcessorInvocation;
|
| RequiredZoeDepthImageProcessorInvocation,
|
||||||
|
'id'
|
||||||
|
>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Type guard for CannyImageProcessorInvocation
|
* Type guard for CannyImageProcessorInvocation
|
||||||
|
@ -3,6 +3,7 @@ import { RootState } from 'app/store/store';
|
|||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import { some } from 'lodash-es';
|
import { some } from 'lodash-es';
|
||||||
import { ImageUsage } from './types';
|
import { ImageUsage } from './types';
|
||||||
|
import { isInvocationNode } from 'features/nodes/types/types';
|
||||||
|
|
||||||
export const getImageUsage = (state: RootState, image_name: string) => {
|
export const getImageUsage = (state: RootState, image_name: string) => {
|
||||||
const { generation, canvas, nodes, controlNet } = state;
|
const { generation, canvas, nodes, controlNet } = state;
|
||||||
@ -12,11 +13,11 @@ export const getImageUsage = (state: RootState, image_name: string) => {
|
|||||||
(obj) => obj.kind === 'image' && obj.imageName === image_name
|
(obj) => obj.kind === 'image' && obj.imageName === image_name
|
||||||
);
|
);
|
||||||
|
|
||||||
const isNodesImage = nodes.nodes.some((node) => {
|
const isNodesImage = nodes.nodes.filter(isInvocationNode).some((node) => {
|
||||||
return some(
|
return some(
|
||||||
node.data.inputs,
|
node.data.inputs,
|
||||||
(input) =>
|
(input) =>
|
||||||
input.type === 'image' && input.value?.image_name === image_name
|
input.type === 'ImageField' && input.value?.image_name === image_name
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -6,23 +6,18 @@ import {
|
|||||||
useSensor,
|
useSensor,
|
||||||
useSensors,
|
useSensors,
|
||||||
} from '@dnd-kit/core';
|
} from '@dnd-kit/core';
|
||||||
import { snapCenterToCursor } from '@dnd-kit/modifiers';
|
import { logger } from 'app/logging/logger';
|
||||||
import { dndDropped } from 'app/store/middleware/listenerMiddleware/listeners/imageDropped';
|
import { dndDropped } from 'app/store/middleware/listenerMiddleware/listeners/imageDropped';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { parseify } from 'common/util/serialize';
|
||||||
import { AnimatePresence, motion } from 'framer-motion';
|
import { AnimatePresence, motion } from 'framer-motion';
|
||||||
import { PropsWithChildren, memo, useCallback, useState } from 'react';
|
import { PropsWithChildren, memo, useCallback, useState } from 'react';
|
||||||
|
import { useScaledModifer } from '../hooks/useScaledCenteredModifer';
|
||||||
|
import { DragEndEvent, DragStartEvent, TypesafeDraggableData } from '../types';
|
||||||
|
import { DndContextTypesafe } from './DndContextTypesafe';
|
||||||
import DragPreview from './DragPreview';
|
import DragPreview from './DragPreview';
|
||||||
import {
|
|
||||||
DndContext,
|
|
||||||
DragEndEvent,
|
|
||||||
DragStartEvent,
|
|
||||||
TypesafeDraggableData,
|
|
||||||
} from './typesafeDnd';
|
|
||||||
import { logger } from 'app/logging/logger';
|
|
||||||
|
|
||||||
type ImageDndContextProps = PropsWithChildren;
|
const AppDndContext = (props: PropsWithChildren) => {
|
||||||
|
|
||||||
const ImageDndContext = (props: ImageDndContextProps) => {
|
|
||||||
const [activeDragData, setActiveDragData] =
|
const [activeDragData, setActiveDragData] =
|
||||||
useState<TypesafeDraggableData | null>(null);
|
useState<TypesafeDraggableData | null>(null);
|
||||||
const log = logger('images');
|
const log = logger('images');
|
||||||
@ -31,7 +26,10 @@ const ImageDndContext = (props: ImageDndContextProps) => {
|
|||||||
|
|
||||||
const handleDragStart = useCallback(
|
const handleDragStart = useCallback(
|
||||||
(event: DragStartEvent) => {
|
(event: DragStartEvent) => {
|
||||||
log.trace({ dragData: event.active.data.current }, 'Drag started');
|
log.trace(
|
||||||
|
{ dragData: parseify(event.active.data.current) },
|
||||||
|
'Drag started'
|
||||||
|
);
|
||||||
const activeData = event.active.data.current;
|
const activeData = event.active.data.current;
|
||||||
if (!activeData) {
|
if (!activeData) {
|
||||||
return;
|
return;
|
||||||
@ -43,7 +41,10 @@ const ImageDndContext = (props: ImageDndContextProps) => {
|
|||||||
|
|
||||||
const handleDragEnd = useCallback(
|
const handleDragEnd = useCallback(
|
||||||
(event: DragEndEvent) => {
|
(event: DragEndEvent) => {
|
||||||
log.trace({ dragData: event.active.data.current }, 'Drag ended');
|
log.trace(
|
||||||
|
{ dragData: parseify(event.active.data.current) },
|
||||||
|
'Drag ended'
|
||||||
|
);
|
||||||
const overData = event.over?.data.current;
|
const overData = event.over?.data.current;
|
||||||
if (!activeDragData || !overData) {
|
if (!activeDragData || !overData) {
|
||||||
return;
|
return;
|
||||||
@ -69,15 +70,29 @@ const ImageDndContext = (props: ImageDndContextProps) => {
|
|||||||
|
|
||||||
const sensors = useSensors(mouseSensor, touchSensor);
|
const sensors = useSensors(mouseSensor, touchSensor);
|
||||||
|
|
||||||
|
const scaledModifier = useScaledModifer();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<DndContext
|
<DndContextTypesafe
|
||||||
onDragStart={handleDragStart}
|
onDragStart={handleDragStart}
|
||||||
onDragEnd={handleDragEnd}
|
onDragEnd={handleDragEnd}
|
||||||
sensors={sensors}
|
sensors={sensors}
|
||||||
collisionDetection={pointerWithin}
|
collisionDetection={pointerWithin}
|
||||||
|
autoScroll={false}
|
||||||
>
|
>
|
||||||
{props.children}
|
{props.children}
|
||||||
<DragOverlay dropAnimation={null} modifiers={[snapCenterToCursor]}>
|
<DragOverlay
|
||||||
|
dropAnimation={null}
|
||||||
|
modifiers={[scaledModifier]}
|
||||||
|
style={{
|
||||||
|
width: 'min-content',
|
||||||
|
height: 'min-content',
|
||||||
|
cursor: 'none',
|
||||||
|
userSelect: 'none',
|
||||||
|
// expand overlay to prevent cursor from going outside it and displaying
|
||||||
|
padding: '10rem',
|
||||||
|
}}
|
||||||
|
>
|
||||||
<AnimatePresence>
|
<AnimatePresence>
|
||||||
{activeDragData && (
|
{activeDragData && (
|
||||||
<motion.div
|
<motion.div
|
||||||
@ -98,8 +113,8 @@ const ImageDndContext = (props: ImageDndContextProps) => {
|
|||||||
)}
|
)}
|
||||||
</AnimatePresence>
|
</AnimatePresence>
|
||||||
</DragOverlay>
|
</DragOverlay>
|
||||||
</DndContext>
|
</DndContextTypesafe>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default memo(ImageDndContext);
|
export default memo(AppDndContext);
|
@ -0,0 +1,6 @@
|
|||||||
|
import { DndContext } from '@dnd-kit/core';
|
||||||
|
import { DndContextTypesafeProps } from '../types';
|
||||||
|
|
||||||
|
export function DndContextTypesafe(props: DndContextTypesafeProps) {
|
||||||
|
return <DndContext {...props} />;
|
||||||
|
}
|
@ -1,6 +1,6 @@
|
|||||||
import { Box, ChakraProps, Flex, Heading, Image } from '@chakra-ui/react';
|
import { Box, ChakraProps, Flex, Heading, Image, Text } from '@chakra-ui/react';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { TypesafeDraggableData } from './typesafeDnd';
|
import { TypesafeDraggableData } from '../types';
|
||||||
|
|
||||||
type OverlayDragImageProps = {
|
type OverlayDragImageProps = {
|
||||||
dragData: TypesafeDraggableData | null;
|
dragData: TypesafeDraggableData | null;
|
||||||
@ -30,19 +30,38 @@ const DragPreview = (props: OverlayDragImageProps) => {
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (props.dragData.payloadType === 'NODE_FIELD') {
|
||||||
|
const { field, fieldTemplate } = props.dragData.payload;
|
||||||
|
return (
|
||||||
|
<Box
|
||||||
|
sx={{
|
||||||
|
position: 'relative',
|
||||||
|
p: 2,
|
||||||
|
px: 3,
|
||||||
|
opacity: 0.7,
|
||||||
|
bg: 'base.300',
|
||||||
|
borderRadius: 'base',
|
||||||
|
boxShadow: 'dark-lg',
|
||||||
|
whiteSpace: 'nowrap',
|
||||||
|
fontSize: 'sm',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Text>{field.label || fieldTemplate.title}</Text>
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
if (props.dragData.payloadType === 'IMAGE_DTO') {
|
if (props.dragData.payloadType === 'IMAGE_DTO') {
|
||||||
const { thumbnail_url, width, height } = props.dragData.payload.imageDTO;
|
const { thumbnail_url, width, height } = props.dragData.payload.imageDTO;
|
||||||
return (
|
return (
|
||||||
<Box
|
<Box
|
||||||
sx={{
|
sx={{
|
||||||
position: 'relative',
|
position: 'relative',
|
||||||
width: '100%',
|
width: 'full',
|
||||||
height: '100%',
|
height: 'full',
|
||||||
display: 'flex',
|
display: 'flex',
|
||||||
alignItems: 'center',
|
alignItems: 'center',
|
||||||
justifyContent: 'center',
|
justifyContent: 'center',
|
||||||
userSelect: 'none',
|
|
||||||
cursor: 'none',
|
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Image
|
<Image
|
||||||
@ -62,8 +81,6 @@ const DragPreview = (props: OverlayDragImageProps) => {
|
|||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
sx={{
|
sx={{
|
||||||
cursor: 'none',
|
|
||||||
userSelect: 'none',
|
|
||||||
position: 'relative',
|
position: 'relative',
|
||||||
alignItems: 'center',
|
alignItems: 'center',
|
||||||
justifyContent: 'center',
|
justifyContent: 'center',
|
@ -0,0 +1,15 @@
|
|||||||
|
import { useDraggable, useDroppable } from '@dnd-kit/core';
|
||||||
|
import {
|
||||||
|
UseDraggableTypesafeArguments,
|
||||||
|
UseDraggableTypesafeReturnValue,
|
||||||
|
UseDroppableTypesafeArguments,
|
||||||
|
UseDroppableTypesafeReturnValue,
|
||||||
|
} from '../types';
|
||||||
|
|
||||||
|
export function useDroppableTypesafe(props: UseDroppableTypesafeArguments) {
|
||||||
|
return useDroppable(props) as UseDroppableTypesafeReturnValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useDraggableTypesafe(props: UseDraggableTypesafeArguments) {
|
||||||
|
return useDraggable(props) as UseDraggableTypesafeReturnValue;
|
||||||
|
}
|
@ -0,0 +1,51 @@
|
|||||||
|
import type { Modifier } from '@dnd-kit/core';
|
||||||
|
import { getEventCoordinates } from '@dnd-kit/utilities';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
|
import { useCallback } from 'react';
|
||||||
|
|
||||||
|
const selectZoom = createSelector(
|
||||||
|
[stateSelector, activeTabNameSelector],
|
||||||
|
({ nodes }, activeTabName) =>
|
||||||
|
activeTabName === 'nodes' ? nodes.viewport.zoom : 1
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Applies scaling to the drag transform (if on node editor tab) and centers it on cursor.
|
||||||
|
*/
|
||||||
|
export const useScaledModifer = () => {
|
||||||
|
const zoom = useAppSelector(selectZoom);
|
||||||
|
const modifier: Modifier = useCallback(
|
||||||
|
({ activatorEvent, draggingNodeRect, transform }) => {
|
||||||
|
if (draggingNodeRect && activatorEvent) {
|
||||||
|
const activatorCoordinates = getEventCoordinates(activatorEvent);
|
||||||
|
|
||||||
|
if (!activatorCoordinates) {
|
||||||
|
return transform;
|
||||||
|
}
|
||||||
|
|
||||||
|
const offsetX = activatorCoordinates.x - draggingNodeRect.left;
|
||||||
|
const offsetY = activatorCoordinates.y - draggingNodeRect.top;
|
||||||
|
|
||||||
|
const x = transform.x + offsetX - draggingNodeRect.width / 2;
|
||||||
|
const y = transform.y + offsetY - draggingNodeRect.height / 2;
|
||||||
|
const scaleX = transform.scaleX * zoom;
|
||||||
|
const scaleY = transform.scaleY * zoom;
|
||||||
|
|
||||||
|
return {
|
||||||
|
x,
|
||||||
|
y,
|
||||||
|
scaleX,
|
||||||
|
scaleY,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
return transform;
|
||||||
|
},
|
||||||
|
[zoom]
|
||||||
|
);
|
||||||
|
|
||||||
|
return modifier;
|
||||||
|
};
|
@ -3,7 +3,6 @@ import {
|
|||||||
Active,
|
Active,
|
||||||
Collision,
|
Collision,
|
||||||
DndContextProps,
|
DndContextProps,
|
||||||
DndContext as OriginalDndContext,
|
|
||||||
Over,
|
Over,
|
||||||
Translate,
|
Translate,
|
||||||
UseDraggableArguments,
|
UseDraggableArguments,
|
||||||
@ -11,6 +10,10 @@ import {
|
|||||||
useDraggable as useOriginalDraggable,
|
useDraggable as useOriginalDraggable,
|
||||||
useDroppable as useOriginalDroppable,
|
useDroppable as useOriginalDroppable,
|
||||||
} from '@dnd-kit/core';
|
} from '@dnd-kit/core';
|
||||||
|
import {
|
||||||
|
InputFieldTemplate,
|
||||||
|
InputFieldValue,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
|
|
||||||
type BaseDropData = {
|
type BaseDropData = {
|
||||||
@ -62,6 +65,10 @@ export type RemoveFromBoardDropData = BaseDropData & {
|
|||||||
actionType: 'REMOVE_FROM_BOARD';
|
actionType: 'REMOVE_FROM_BOARD';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type AddFieldToLinearViewDropData = BaseDropData & {
|
||||||
|
actionType: 'ADD_FIELD_TO_LINEAR';
|
||||||
|
};
|
||||||
|
|
||||||
export type TypesafeDroppableData =
|
export type TypesafeDroppableData =
|
||||||
| CurrentImageDropData
|
| CurrentImageDropData
|
||||||
| InitialImageDropData
|
| InitialImageDropData
|
||||||
@ -71,12 +78,22 @@ export type TypesafeDroppableData =
|
|||||||
| AddToBatchDropData
|
| AddToBatchDropData
|
||||||
| NodesMultiImageDropData
|
| NodesMultiImageDropData
|
||||||
| AddToBoardDropData
|
| AddToBoardDropData
|
||||||
| RemoveFromBoardDropData;
|
| RemoveFromBoardDropData
|
||||||
|
| AddFieldToLinearViewDropData;
|
||||||
|
|
||||||
type BaseDragData = {
|
type BaseDragData = {
|
||||||
id: string;
|
id: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type NodeFieldDraggableData = BaseDragData & {
|
||||||
|
payloadType: 'NODE_FIELD';
|
||||||
|
payload: {
|
||||||
|
nodeId: string;
|
||||||
|
field: InputFieldValue;
|
||||||
|
fieldTemplate: InputFieldTemplate;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
export type ImageDraggableData = BaseDragData & {
|
export type ImageDraggableData = BaseDragData & {
|
||||||
payloadType: 'IMAGE_DTO';
|
payloadType: 'IMAGE_DTO';
|
||||||
payload: { imageDTO: ImageDTO };
|
payload: { imageDTO: ImageDTO };
|
||||||
@ -87,14 +104,17 @@ export type ImageDTOsDraggableData = BaseDragData & {
|
|||||||
payload: { imageDTOs: ImageDTO[] };
|
payload: { imageDTOs: ImageDTO[] };
|
||||||
};
|
};
|
||||||
|
|
||||||
export type TypesafeDraggableData = ImageDraggableData | ImageDTOsDraggableData;
|
export type TypesafeDraggableData =
|
||||||
|
| NodeFieldDraggableData
|
||||||
|
| ImageDraggableData
|
||||||
|
| ImageDTOsDraggableData;
|
||||||
|
|
||||||
interface UseDroppableTypesafeArguments
|
export interface UseDroppableTypesafeArguments
|
||||||
extends Omit<UseDroppableArguments, 'data'> {
|
extends Omit<UseDroppableArguments, 'data'> {
|
||||||
data?: TypesafeDroppableData;
|
data?: TypesafeDroppableData;
|
||||||
}
|
}
|
||||||
|
|
||||||
type UseDroppableTypesafeReturnValue = Omit<
|
export type UseDroppableTypesafeReturnValue = Omit<
|
||||||
ReturnType<typeof useOriginalDroppable>,
|
ReturnType<typeof useOriginalDroppable>,
|
||||||
'active' | 'over'
|
'active' | 'over'
|
||||||
> & {
|
> & {
|
||||||
@ -102,16 +122,12 @@ type UseDroppableTypesafeReturnValue = Omit<
|
|||||||
over: TypesafeOver | null;
|
over: TypesafeOver | null;
|
||||||
};
|
};
|
||||||
|
|
||||||
export function useDroppable(props: UseDroppableTypesafeArguments) {
|
export interface UseDraggableTypesafeArguments
|
||||||
return useOriginalDroppable(props) as UseDroppableTypesafeReturnValue;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface UseDraggableTypesafeArguments
|
|
||||||
extends Omit<UseDraggableArguments, 'data'> {
|
extends Omit<UseDraggableArguments, 'data'> {
|
||||||
data?: TypesafeDraggableData;
|
data?: TypesafeDraggableData;
|
||||||
}
|
}
|
||||||
|
|
||||||
type UseDraggableTypesafeReturnValue = Omit<
|
export type UseDraggableTypesafeReturnValue = Omit<
|
||||||
ReturnType<typeof useOriginalDraggable>,
|
ReturnType<typeof useOriginalDraggable>,
|
||||||
'active' | 'over'
|
'active' | 'over'
|
||||||
> & {
|
> & {
|
||||||
@ -119,102 +135,14 @@ type UseDraggableTypesafeReturnValue = Omit<
|
|||||||
over: TypesafeOver | null;
|
over: TypesafeOver | null;
|
||||||
};
|
};
|
||||||
|
|
||||||
export function useDraggable(props: UseDraggableTypesafeArguments) {
|
export interface TypesafeActive extends Omit<Active, 'data'> {
|
||||||
return useOriginalDraggable(props) as UseDraggableTypesafeReturnValue;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface TypesafeActive extends Omit<Active, 'data'> {
|
|
||||||
data: React.MutableRefObject<TypesafeDraggableData | undefined>;
|
data: React.MutableRefObject<TypesafeDraggableData | undefined>;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface TypesafeOver extends Omit<Over, 'data'> {
|
export interface TypesafeOver extends Omit<Over, 'data'> {
|
||||||
data: React.MutableRefObject<TypesafeDroppableData | undefined>;
|
data: React.MutableRefObject<TypesafeDroppableData | undefined>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const isValidDrop = (
|
|
||||||
overData: TypesafeDroppableData | undefined,
|
|
||||||
active: TypesafeActive | null
|
|
||||||
) => {
|
|
||||||
if (!overData || !active?.data.current) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
const { actionType } = overData;
|
|
||||||
const { payloadType } = active.data.current;
|
|
||||||
|
|
||||||
if (overData.id === active.data.current.id) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
switch (actionType) {
|
|
||||||
case 'SET_CURRENT_IMAGE':
|
|
||||||
return payloadType === 'IMAGE_DTO';
|
|
||||||
case 'SET_INITIAL_IMAGE':
|
|
||||||
return payloadType === 'IMAGE_DTO';
|
|
||||||
case 'SET_CONTROLNET_IMAGE':
|
|
||||||
return payloadType === 'IMAGE_DTO';
|
|
||||||
case 'SET_CANVAS_INITIAL_IMAGE':
|
|
||||||
return payloadType === 'IMAGE_DTO';
|
|
||||||
case 'SET_NODES_IMAGE':
|
|
||||||
return payloadType === 'IMAGE_DTO';
|
|
||||||
case 'SET_MULTI_NODES_IMAGE':
|
|
||||||
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
|
||||||
case 'ADD_TO_BATCH':
|
|
||||||
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
|
||||||
case 'ADD_TO_BOARD': {
|
|
||||||
// If the board is the same, don't allow the drop
|
|
||||||
|
|
||||||
// Check the payload types
|
|
||||||
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
|
||||||
if (!isPayloadValid) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the image's board is the board we are dragging onto
|
|
||||||
if (payloadType === 'IMAGE_DTO') {
|
|
||||||
const { imageDTO } = active.data.current.payload;
|
|
||||||
const currentBoard = imageDTO.board_id ?? 'none';
|
|
||||||
const destinationBoard = overData.context.boardId;
|
|
||||||
|
|
||||||
return currentBoard !== destinationBoard;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (payloadType === 'IMAGE_DTOS') {
|
|
||||||
// TODO (multi-select)
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
case 'REMOVE_FROM_BOARD': {
|
|
||||||
// If the board is the same, don't allow the drop
|
|
||||||
|
|
||||||
// Check the payload types
|
|
||||||
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
|
||||||
if (!isPayloadValid) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the image's board is the board we are dragging onto
|
|
||||||
if (payloadType === 'IMAGE_DTO') {
|
|
||||||
const { imageDTO } = active.data.current.payload;
|
|
||||||
const currentBoard = imageDTO.board_id;
|
|
||||||
|
|
||||||
return currentBoard !== 'none';
|
|
||||||
}
|
|
||||||
|
|
||||||
if (payloadType === 'IMAGE_DTOS') {
|
|
||||||
// TODO (multi-select)
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
interface DragEvent {
|
interface DragEvent {
|
||||||
activatorEvent: Event;
|
activatorEvent: Event;
|
||||||
active: TypesafeActive;
|
active: TypesafeActive;
|
||||||
@ -240,6 +168,3 @@ export interface DndContextTypesafeProps
|
|||||||
onDragEnd?(event: DragEndEvent): void;
|
onDragEnd?(event: DragEndEvent): void;
|
||||||
onDragCancel?(event: DragCancelEvent): void;
|
onDragCancel?(event: DragCancelEvent): void;
|
||||||
}
|
}
|
||||||
export function DndContext(props: DndContextTypesafeProps) {
|
|
||||||
return <OriginalDndContext {...props} />;
|
|
||||||
}
|
|
87
invokeai/frontend/web/src/features/dnd/util/isValidDrop.ts
Normal file
87
invokeai/frontend/web/src/features/dnd/util/isValidDrop.ts
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
import { TypesafeActive, TypesafeDroppableData } from '../types';
|
||||||
|
|
||||||
|
export const isValidDrop = (
|
||||||
|
overData: TypesafeDroppableData | undefined,
|
||||||
|
active: TypesafeActive | null
|
||||||
|
) => {
|
||||||
|
if (!overData || !active?.data.current) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { actionType } = overData;
|
||||||
|
const { payloadType } = active.data.current;
|
||||||
|
|
||||||
|
if (overData.id === active.data.current.id) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (actionType) {
|
||||||
|
case 'ADD_FIELD_TO_LINEAR':
|
||||||
|
return payloadType === 'NODE_FIELD';
|
||||||
|
case 'SET_CURRENT_IMAGE':
|
||||||
|
return payloadType === 'IMAGE_DTO';
|
||||||
|
case 'SET_INITIAL_IMAGE':
|
||||||
|
return payloadType === 'IMAGE_DTO';
|
||||||
|
case 'SET_CONTROLNET_IMAGE':
|
||||||
|
return payloadType === 'IMAGE_DTO';
|
||||||
|
case 'SET_CANVAS_INITIAL_IMAGE':
|
||||||
|
return payloadType === 'IMAGE_DTO';
|
||||||
|
case 'SET_NODES_IMAGE':
|
||||||
|
return payloadType === 'IMAGE_DTO';
|
||||||
|
case 'SET_MULTI_NODES_IMAGE':
|
||||||
|
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
||||||
|
case 'ADD_TO_BATCH':
|
||||||
|
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
||||||
|
case 'ADD_TO_BOARD': {
|
||||||
|
// If the board is the same, don't allow the drop
|
||||||
|
|
||||||
|
// Check the payload types
|
||||||
|
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
||||||
|
if (!isPayloadValid) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the image's board is the board we are dragging onto
|
||||||
|
if (payloadType === 'IMAGE_DTO') {
|
||||||
|
const { imageDTO } = active.data.current.payload;
|
||||||
|
const currentBoard = imageDTO.board_id ?? 'none';
|
||||||
|
const destinationBoard = overData.context.boardId;
|
||||||
|
|
||||||
|
return currentBoard !== destinationBoard;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (payloadType === 'IMAGE_DTOS') {
|
||||||
|
// TODO (multi-select)
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
case 'REMOVE_FROM_BOARD': {
|
||||||
|
// If the board is the same, don't allow the drop
|
||||||
|
|
||||||
|
// Check the payload types
|
||||||
|
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
||||||
|
if (!isPayloadValid) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the image's board is the board we are dragging onto
|
||||||
|
if (payloadType === 'IMAGE_DTO') {
|
||||||
|
const { imageDTO } = active.data.current.payload;
|
||||||
|
const currentBoard = imageDTO.board_id;
|
||||||
|
|
||||||
|
return currentBoard !== 'none';
|
||||||
|
}
|
||||||
|
|
||||||
|
if (payloadType === 'IMAGE_DTOS') {
|
||||||
|
// TODO (multi-select)
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
@ -11,7 +11,6 @@ import {
|
|||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||||
import { AddToBoardDropData } from 'app/components/ImageDnd/typesafeDnd';
|
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
@ -32,6 +31,7 @@ import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
|||||||
import { BoardDTO } from 'services/api/types';
|
import { BoardDTO } from 'services/api/types';
|
||||||
import AutoAddIcon from '../AutoAddIcon';
|
import AutoAddIcon from '../AutoAddIcon';
|
||||||
import BoardContextMenu from '../BoardContextMenu';
|
import BoardContextMenu from '../BoardContextMenu';
|
||||||
|
import { AddToBoardDropData } from 'features/dnd/types';
|
||||||
|
|
||||||
interface GalleryBoardProps {
|
interface GalleryBoardProps {
|
||||||
board: BoardDTO;
|
board: BoardDTO;
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import { As, Badge, Flex } from '@chakra-ui/react';
|
import { As, Badge, Flex } from '@chakra-ui/react';
|
||||||
import { TypesafeDroppableData } from 'app/components/ImageDnd/typesafeDnd';
|
|
||||||
import IAIDroppable from 'common/components/IAIDroppable';
|
import IAIDroppable from 'common/components/IAIDroppable';
|
||||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||||
|
import { TypesafeDroppableData } from 'features/dnd/types';
|
||||||
import { BoardId } from 'features/gallery/store/types';
|
import { BoardId } from 'features/gallery/store/types';
|
||||||
import { ReactNode } from 'react';
|
import { ReactNode } from 'react';
|
||||||
import BoardContextMenu from '../BoardContextMenu';
|
import BoardContextMenu from '../BoardContextMenu';
|
||||||
|
@ -1,15 +1,15 @@
|
|||||||
import { Box, Flex, Image, Text } from '@chakra-ui/react';
|
import { Box, Flex, Image, Text } from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { RemoveFromBoardDropData } from 'app/components/ImageDnd/typesafeDnd';
|
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import InvokeAILogoImage from 'assets/images/logo.png';
|
import InvokeAILogoImage from 'assets/images/logo.png';
|
||||||
import IAIDroppable from 'common/components/IAIDroppable';
|
import IAIDroppable from 'common/components/IAIDroppable';
|
||||||
import SelectionOverlay from 'common/components/SelectionOverlay';
|
import SelectionOverlay from 'common/components/SelectionOverlay';
|
||||||
|
import { RemoveFromBoardDropData } from 'features/dnd/types';
|
||||||
import {
|
import {
|
||||||
boardIdSelected,
|
|
||||||
autoAddBoardIdChanged,
|
autoAddBoardIdChanged,
|
||||||
|
boardIdSelected,
|
||||||
} from 'features/gallery/store/gallerySlice';
|
} from 'features/gallery/store/gallerySlice';
|
||||||
import { memo, useCallback, useMemo, useState } from 'react';
|
import { memo, useCallback, useMemo, useState } from 'react';
|
||||||
import { useBoardName } from 'services/api/hooks/useBoardName';
|
import { useBoardName } from 'services/api/hooks/useBoardName';
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
import { Box, Flex, Image } from '@chakra-ui/react';
|
import { Box, Flex, Image } from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||||
import {
|
|
||||||
TypesafeDraggableData,
|
|
||||||
TypesafeDroppableData,
|
|
||||||
} from 'app/components/ImageDnd/typesafeDnd';
|
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIDndImage from 'common/components/IAIDndImage';
|
import IAIDndImage from 'common/components/IAIDndImage';
|
||||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||||
|
import {
|
||||||
|
TypesafeDraggableData,
|
||||||
|
TypesafeDroppableData,
|
||||||
|
} from 'features/dnd/types';
|
||||||
import { useNextPrevImage } from 'features/gallery/hooks/useNextPrevImage';
|
import { useNextPrevImage } from 'features/gallery/hooks/useNextPrevImage';
|
||||||
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
|
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
|
||||||
import { AnimatePresence, motion } from 'framer-motion';
|
import { AnimatePresence, motion } from 'framer-motion';
|
||||||
|
@ -11,7 +11,6 @@ import {
|
|||||||
autoAssignBoardOnClickChanged,
|
autoAssignBoardOnClickChanged,
|
||||||
setGalleryImageMinimumWidth,
|
setGalleryImageMinimumWidth,
|
||||||
shouldAutoSwitchChanged,
|
shouldAutoSwitchChanged,
|
||||||
shouldShowDeleteButtonChanged,
|
|
||||||
} from 'features/gallery/store/gallerySlice';
|
} from 'features/gallery/store/gallerySlice';
|
||||||
import { ChangeEvent, useCallback } from 'react';
|
import { ChangeEvent, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -26,14 +25,12 @@ const selector = createSelector(
|
|||||||
galleryImageMinimumWidth,
|
galleryImageMinimumWidth,
|
||||||
shouldAutoSwitch,
|
shouldAutoSwitch,
|
||||||
autoAssignBoardOnClick,
|
autoAssignBoardOnClick,
|
||||||
shouldShowDeleteButton,
|
|
||||||
} = state.gallery;
|
} = state.gallery;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
galleryImageMinimumWidth,
|
galleryImageMinimumWidth,
|
||||||
shouldAutoSwitch,
|
shouldAutoSwitch,
|
||||||
autoAssignBoardOnClick,
|
autoAssignBoardOnClick,
|
||||||
shouldShowDeleteButton,
|
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
@ -43,12 +40,8 @@ const GallerySettingsPopover = () => {
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const {
|
const { galleryImageMinimumWidth, shouldAutoSwitch, autoAssignBoardOnClick } =
|
||||||
galleryImageMinimumWidth,
|
useAppSelector(selector);
|
||||||
shouldAutoSwitch,
|
|
||||||
autoAssignBoardOnClick,
|
|
||||||
shouldShowDeleteButton,
|
|
||||||
} = useAppSelector(selector);
|
|
||||||
|
|
||||||
const handleChangeGalleryImageMinimumWidth = useCallback(
|
const handleChangeGalleryImageMinimumWidth = useCallback(
|
||||||
(v: number) => {
|
(v: number) => {
|
||||||
@ -68,13 +61,6 @@ const GallerySettingsPopover = () => {
|
|||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleChangeShowDeleteButton = useCallback(
|
|
||||||
(e: ChangeEvent<HTMLInputElement>) => {
|
|
||||||
dispatch(shouldShowDeleteButtonChanged(e.target.checked));
|
|
||||||
},
|
|
||||||
[dispatch]
|
|
||||||
);
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAIPopover
|
<IAIPopover
|
||||||
triggerComponent={
|
triggerComponent={
|
||||||
@ -90,7 +76,7 @@ const GallerySettingsPopover = () => {
|
|||||||
<IAISlider
|
<IAISlider
|
||||||
value={galleryImageMinimumWidth}
|
value={galleryImageMinimumWidth}
|
||||||
onChange={handleChangeGalleryImageMinimumWidth}
|
onChange={handleChangeGalleryImageMinimumWidth}
|
||||||
min={32}
|
min={45}
|
||||||
max={256}
|
max={256}
|
||||||
hideTooltip={true}
|
hideTooltip={true}
|
||||||
label={t('gallery.galleryImageSize')}
|
label={t('gallery.galleryImageSize')}
|
||||||
@ -102,11 +88,6 @@ const GallerySettingsPopover = () => {
|
|||||||
isChecked={shouldAutoSwitch}
|
isChecked={shouldAutoSwitch}
|
||||||
onChange={handleChangeAutoSwitch}
|
onChange={handleChangeAutoSwitch}
|
||||||
/>
|
/>
|
||||||
<IAISwitch
|
|
||||||
label="Show Delete Button"
|
|
||||||
isChecked={shouldShowDeleteButton}
|
|
||||||
onChange={handleChangeShowDeleteButton}
|
|
||||||
/>
|
|
||||||
<IAISimpleCheckbox
|
<IAISimpleCheckbox
|
||||||
label={t('gallery.autoAssignBoardOnClick')}
|
label={t('gallery.autoAssignBoardOnClick')}
|
||||||
isChecked={autoAssignBoardOnClick}
|
isChecked={autoAssignBoardOnClick}
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
import { MenuList } from '@chakra-ui/react';
|
import { MenuList } from '@chakra-ui/react';
|
||||||
import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu';
|
import {
|
||||||
|
IAIContextMenu,
|
||||||
|
IAIContextMenuProps,
|
||||||
|
} from 'common/components/IAIContextMenu';
|
||||||
import { MouseEvent, memo, useCallback } from 'react';
|
import { MouseEvent, memo, useCallback } from 'react';
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
import { menuListMotionProps } from 'theme/components/menu';
|
import { menuListMotionProps } from 'theme/components/menu';
|
||||||
@ -12,7 +15,7 @@ import MultipleSelectionMenuItems from './MultipleSelectionMenuItems';
|
|||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
imageDTO: ImageDTO | undefined;
|
imageDTO: ImageDTO | undefined;
|
||||||
children: ContextMenuProps<HTMLDivElement>['children'];
|
children: IAIContextMenuProps<HTMLDivElement>['children'];
|
||||||
};
|
};
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
@ -33,7 +36,7 @@ const ImageContextMenu = ({ imageDTO, children }: Props) => {
|
|||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<ContextMenu<HTMLDivElement>
|
<IAIContextMenu<HTMLDivElement>
|
||||||
menuProps={{ size: 'sm', isLazy: true }}
|
menuProps={{ size: 'sm', isLazy: true }}
|
||||||
menuButtonProps={{
|
menuButtonProps={{
|
||||||
bg: 'transparent',
|
bg: 'transparent',
|
||||||
@ -68,7 +71,7 @@ const ImageContextMenu = ({ imageDTO, children }: Props) => {
|
|||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{children}
|
{children}
|
||||||
</ContextMenu>
|
</IAIContextMenu>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -5,13 +5,21 @@ import {
|
|||||||
isModalOpenChanged,
|
isModalOpenChanged,
|
||||||
} from 'features/changeBoardModal/store/slice';
|
} from 'features/changeBoardModal/store/slice';
|
||||||
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
||||||
import { useCallback } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import { FaFolder, FaTrash } from 'react-icons/fa';
|
import { FaFolder, FaTrash } from 'react-icons/fa';
|
||||||
|
import { MdStar, MdStarBorder } from 'react-icons/md';
|
||||||
|
import {
|
||||||
|
useStarImagesMutation,
|
||||||
|
useUnstarImagesMutation,
|
||||||
|
} from '../../../../services/api/endpoints/images';
|
||||||
|
|
||||||
const MultipleSelectionMenuItems = () => {
|
const MultipleSelectionMenuItems = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const selection = useAppSelector((state) => state.gallery.selection);
|
const selection = useAppSelector((state) => state.gallery.selection);
|
||||||
|
|
||||||
|
const [starImages] = useStarImagesMutation();
|
||||||
|
const [unstarImages] = useUnstarImagesMutation();
|
||||||
|
|
||||||
const handleChangeBoard = useCallback(() => {
|
const handleChangeBoard = useCallback(() => {
|
||||||
dispatch(imagesToChangeSelected(selection));
|
dispatch(imagesToChangeSelected(selection));
|
||||||
dispatch(isModalOpenChanged(true));
|
dispatch(isModalOpenChanged(true));
|
||||||
@ -21,8 +29,37 @@ const MultipleSelectionMenuItems = () => {
|
|||||||
dispatch(imagesToDeleteSelected(selection));
|
dispatch(imagesToDeleteSelected(selection));
|
||||||
}, [dispatch, selection]);
|
}, [dispatch, selection]);
|
||||||
|
|
||||||
|
const handleStarSelection = useCallback(() => {
|
||||||
|
starImages({ imageDTOs: selection });
|
||||||
|
}, [starImages, selection]);
|
||||||
|
|
||||||
|
const handleUnstarSelection = useCallback(() => {
|
||||||
|
unstarImages({ imageDTOs: selection });
|
||||||
|
}, [unstarImages, selection]);
|
||||||
|
|
||||||
|
const areAllStarred = useMemo(() => {
|
||||||
|
return selection.every((img) => img.starred);
|
||||||
|
}, [selection]);
|
||||||
|
|
||||||
|
const areAllUnstarred = useMemo(() => {
|
||||||
|
return selection.every((img) => !img.starred);
|
||||||
|
}, [selection]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
|
{areAllStarred && (
|
||||||
|
<MenuItem
|
||||||
|
icon={<MdStarBorder />}
|
||||||
|
onClickCapture={handleUnstarSelection}
|
||||||
|
>
|
||||||
|
Unstar All
|
||||||
|
</MenuItem>
|
||||||
|
)}
|
||||||
|
{(areAllUnstarred || (!areAllStarred && !areAllUnstarred)) && (
|
||||||
|
<MenuItem icon={<MdStar />} onClickCapture={handleStarSelection}>
|
||||||
|
Star All
|
||||||
|
</MenuItem>
|
||||||
|
)}
|
||||||
<MenuItem icon={<FaFolder />} onClickCapture={handleChangeBoard}>
|
<MenuItem icon={<FaFolder />} onClickCapture={handleChangeBoard}>
|
||||||
Change Board
|
Change Board
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
|
@ -29,10 +29,15 @@ import {
|
|||||||
FaShare,
|
FaShare,
|
||||||
FaTrash,
|
FaTrash,
|
||||||
} from 'react-icons/fa';
|
} from 'react-icons/fa';
|
||||||
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
|
import {
|
||||||
|
useGetImageMetadataQuery,
|
||||||
|
useStarImagesMutation,
|
||||||
|
useUnstarImagesMutation,
|
||||||
|
} from 'services/api/endpoints/images';
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
import { useDebounce } from 'use-debounce';
|
import { useDebounce } from 'use-debounce';
|
||||||
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
|
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
|
||||||
|
import { MdStar, MdStarBorder } from 'react-icons/md';
|
||||||
|
|
||||||
type SingleSelectionMenuItemsProps = {
|
type SingleSelectionMenuItemsProps = {
|
||||||
imageDTO: ImageDTO;
|
imageDTO: ImageDTO;
|
||||||
@ -59,6 +64,9 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
|||||||
: debouncedMetadataQueryArg ?? skipToken
|
: debouncedMetadataQueryArg ?? skipToken
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const [starImages] = useStarImagesMutation();
|
||||||
|
const [unstarImages] = useUnstarImagesMutation();
|
||||||
|
|
||||||
const { isClipboardAPIAvailable, copyImageToClipboard } =
|
const { isClipboardAPIAvailable, copyImageToClipboard } =
|
||||||
useCopyImageToClipboard();
|
useCopyImageToClipboard();
|
||||||
|
|
||||||
@ -127,6 +135,14 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
|||||||
copyImageToClipboard(imageDTO.image_url);
|
copyImageToClipboard(imageDTO.image_url);
|
||||||
}, [copyImageToClipboard, imageDTO.image_url]);
|
}, [copyImageToClipboard, imageDTO.image_url]);
|
||||||
|
|
||||||
|
const handleStarImage = useCallback(() => {
|
||||||
|
if (imageDTO) starImages({ imageDTOs: [imageDTO] });
|
||||||
|
}, [starImages, imageDTO]);
|
||||||
|
|
||||||
|
const handleUnstarImage = useCallback(() => {
|
||||||
|
if (imageDTO) unstarImages({ imageDTOs: [imageDTO] });
|
||||||
|
}, [unstarImages, imageDTO]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<MenuItem
|
<MenuItem
|
||||||
@ -196,6 +212,15 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
|||||||
<MenuItem icon={<FaFolder />} onClickCapture={handleChangeBoard}>
|
<MenuItem icon={<FaFolder />} onClickCapture={handleChangeBoard}>
|
||||||
Change Board
|
Change Board
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
|
{imageDTO.starred ? (
|
||||||
|
<MenuItem icon={<MdStar />} onClickCapture={handleUnstarImage}>
|
||||||
|
Unstar Image
|
||||||
|
</MenuItem>
|
||||||
|
) : (
|
||||||
|
<MenuItem icon={<MdStarBorder />} onClickCapture={handleStarImage}>
|
||||||
|
Star Image
|
||||||
|
</MenuItem>
|
||||||
|
)}
|
||||||
<MenuItem
|
<MenuItem
|
||||||
sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
|
sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
|
||||||
icon={<FaTrash />}
|
icon={<FaTrash />}
|
||||||
|
@ -52,11 +52,13 @@ const ImageGalleryContent = () => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<VStack
|
<VStack
|
||||||
|
layerStyle="first"
|
||||||
sx={{
|
sx={{
|
||||||
flexDirection: 'column',
|
flexDirection: 'column',
|
||||||
h: 'full',
|
h: 'full',
|
||||||
w: 'full',
|
w: 'full',
|
||||||
borderRadius: 'base',
|
borderRadius: 'base',
|
||||||
|
p: 2,
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Box sx={{ w: 'full' }}>
|
<Box sx={{ w: 'full' }}>
|
||||||
|
@ -1,17 +1,23 @@
|
|||||||
import { Box, Flex } from '@chakra-ui/react';
|
import { Box, Flex } from '@chakra-ui/react';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import IAIDndImage from 'common/components/IAIDndImage';
|
||||||
|
import IAIFillSkeleton from 'common/components/IAIFillSkeleton';
|
||||||
|
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
||||||
import {
|
import {
|
||||||
ImageDTOsDraggableData,
|
ImageDTOsDraggableData,
|
||||||
ImageDraggableData,
|
ImageDraggableData,
|
||||||
TypesafeDraggableData,
|
TypesafeDraggableData,
|
||||||
} from 'app/components/ImageDnd/typesafeDnd';
|
} from 'features/dnd/types';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import IAIDndImage from 'common/components/IAIDndImage';
|
|
||||||
import IAIFillSkeleton from 'common/components/IAIFillSkeleton';
|
|
||||||
import { useMultiselect } from 'features/gallery/hooks/useMultiselect.ts';
|
import { useMultiselect } from 'features/gallery/hooks/useMultiselect.ts';
|
||||||
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
import { MouseEvent, memo, useCallback, useMemo, useState } from 'react';
|
||||||
import { MouseEvent, memo, useCallback, useMemo } from 'react';
|
|
||||||
import { FaTrash } from 'react-icons/fa';
|
import { FaTrash } from 'react-icons/fa';
|
||||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
import { MdStar, MdStarBorder } from 'react-icons/md';
|
||||||
|
import {
|
||||||
|
useGetImageDTOQuery,
|
||||||
|
useStarImagesMutation,
|
||||||
|
useUnstarImagesMutation,
|
||||||
|
} from 'services/api/endpoints/images';
|
||||||
|
import IAIDndImageIcon from '../../../../common/components/IAIDndImageIcon';
|
||||||
|
|
||||||
interface HoverableImageProps {
|
interface HoverableImageProps {
|
||||||
imageName: string;
|
imageName: string;
|
||||||
@ -21,9 +27,7 @@ const GalleryImage = (props: HoverableImageProps) => {
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { imageName } = props;
|
const { imageName } = props;
|
||||||
const { currentData: imageDTO } = useGetImageDTOQuery(imageName);
|
const { currentData: imageDTO } = useGetImageDTOQuery(imageName);
|
||||||
const shouldShowDeleteButton = useAppSelector(
|
const shift = useAppSelector((state) => state.hotkeys.shift);
|
||||||
(state) => state.gallery.shouldShowDeleteButton
|
|
||||||
);
|
|
||||||
|
|
||||||
const { handleClick, isSelected, selection, selectionCount } =
|
const { handleClick, isSelected, selection, selectionCount } =
|
||||||
useMultiselect(imageDTO);
|
useMultiselect(imageDTO);
|
||||||
@ -59,6 +63,35 @@ const GalleryImage = (props: HoverableImageProps) => {
|
|||||||
}
|
}
|
||||||
}, [imageDTO, selection, selectionCount]);
|
}, [imageDTO, selection, selectionCount]);
|
||||||
|
|
||||||
|
const [starImages] = useStarImagesMutation();
|
||||||
|
const [unstarImages] = useUnstarImagesMutation();
|
||||||
|
|
||||||
|
const toggleStarredState = useCallback(() => {
|
||||||
|
if (imageDTO) {
|
||||||
|
if (imageDTO.starred) {
|
||||||
|
unstarImages({ imageDTOs: [imageDTO] });
|
||||||
|
}
|
||||||
|
if (!imageDTO.starred) {
|
||||||
|
starImages({ imageDTOs: [imageDTO] });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, [starImages, unstarImages, imageDTO]);
|
||||||
|
|
||||||
|
const [isHovered, setIsHovered] = useState(false);
|
||||||
|
|
||||||
|
const handleMouseOver = useCallback(() => {
|
||||||
|
setIsHovered(true);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const handleMouseOut = useCallback(() => {
|
||||||
|
setIsHovered(false);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const starIcon = useMemo(() => {
|
||||||
|
if (imageDTO?.starred) return <MdStar size="20" />;
|
||||||
|
if (!imageDTO?.starred && isHovered) return <MdStarBorder size="20" />;
|
||||||
|
}, [imageDTO?.starred, isHovered]);
|
||||||
|
|
||||||
if (!imageDTO) {
|
if (!imageDTO) {
|
||||||
return <IAIFillSkeleton />;
|
return <IAIFillSkeleton />;
|
||||||
}
|
}
|
||||||
@ -80,16 +113,34 @@ const GalleryImage = (props: HoverableImageProps) => {
|
|||||||
draggableData={draggableData}
|
draggableData={draggableData}
|
||||||
isSelected={isSelected}
|
isSelected={isSelected}
|
||||||
minSize={0}
|
minSize={0}
|
||||||
onClickReset={handleDelete}
|
|
||||||
imageSx={{ w: 'full', h: 'full' }}
|
imageSx={{ w: 'full', h: 'full' }}
|
||||||
isDropDisabled={true}
|
isDropDisabled={true}
|
||||||
isUploadDisabled={true}
|
isUploadDisabled={true}
|
||||||
thumbnail={true}
|
thumbnail={true}
|
||||||
withHoverOverlay
|
withHoverOverlay
|
||||||
resetIcon={<FaTrash />}
|
onMouseOver={handleMouseOver}
|
||||||
resetTooltip="Delete image"
|
onMouseOut={handleMouseOut}
|
||||||
withResetIcon={shouldShowDeleteButton} // removed bc it's too easy to accidentally delete images
|
>
|
||||||
/>
|
<>
|
||||||
|
<IAIDndImageIcon
|
||||||
|
onClick={toggleStarredState}
|
||||||
|
icon={starIcon}
|
||||||
|
tooltip={imageDTO.starred ? 'Unstar' : 'Star'}
|
||||||
|
/>
|
||||||
|
|
||||||
|
{isHovered && shift && (
|
||||||
|
<IAIDndImageIcon
|
||||||
|
onClick={handleDelete}
|
||||||
|
icon={<FaTrash />}
|
||||||
|
tooltip="Delete"
|
||||||
|
styleOverrides={{
|
||||||
|
bottom: 2,
|
||||||
|
top: 'auto',
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
</IAIDndImage>
|
||||||
</Flex>
|
</Flex>
|
||||||
</Box>
|
</Box>
|
||||||
);
|
);
|
||||||
|
@ -26,7 +26,7 @@ const overlayScrollbarsConfig: UseOverlayScrollbarsParams = {
|
|||||||
options: {
|
options: {
|
||||||
scrollbars: {
|
scrollbars: {
|
||||||
visibility: 'auto',
|
visibility: 'auto',
|
||||||
autoHide: 'leave',
|
autoHide: 'scroll',
|
||||||
autoHideDelay: 1300,
|
autoHideDelay: 1300,
|
||||||
theme: 'os-theme-dark',
|
theme: 'os-theme-dark',
|
||||||
},
|
},
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user