mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into fix/inpaint_gen
This commit is contained in:
@ -55,7 +55,7 @@ async def get_version() -> AppVersion:
|
||||
|
||||
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
|
||||
async def get_config() -> AppConfig:
|
||||
infill_methods = ["tile"]
|
||||
infill_methods = ["tile", "lama"]
|
||||
if PatchMatch.patchmatch_available():
|
||||
infill_methods.append("patchmatch")
|
||||
|
||||
|
@ -1,11 +1,11 @@
|
||||
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
import asyncio
|
||||
from inspect import signature
|
||||
|
||||
import logging
|
||||
import uvicorn
|
||||
import socket
|
||||
from inspect import signature
|
||||
from pathlib import Path
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||
@ -13,7 +13,6 @@ from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from pathlib import Path
|
||||
from pydantic.schema import schema
|
||||
|
||||
from .services.config import InvokeAIAppConfig
|
||||
@ -30,9 +29,12 @@ from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
|
||||
|
||||
import torch
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
# noinspection PyUnresolvedReferences
|
||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||
|
||||
|
||||
@ -40,7 +42,6 @@ app_config = InvokeAIAppConfig.get_config()
|
||||
app_config.parse_args()
|
||||
logger = InvokeAILogger.getLogger(config=app_config)
|
||||
|
||||
|
||||
# fix for windows mimetypes registry entries being borked
|
||||
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
@ -122,6 +123,7 @@ def custom_openapi():
|
||||
|
||||
output_schemas = schema(output_types, ref_prefix="#/components/schemas/")
|
||||
for schema_key, output_schema in output_schemas["definitions"].items():
|
||||
output_schema["class"] = "output"
|
||||
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
||||
|
||||
# TODO: note that we assume the schema_key here is the TYPE.__name__
|
||||
@ -130,8 +132,8 @@ def custom_openapi():
|
||||
|
||||
# Add Node Editor UI helper schemas
|
||||
ui_config_schemas = schema([UIConfigBase, _InputField, _OutputField], ref_prefix="#/components/schemas/")
|
||||
for schema_key, output_schema in ui_config_schemas["definitions"].items():
|
||||
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
||||
for schema_key, ui_config_schema in ui_config_schemas["definitions"].items():
|
||||
openapi_schema["components"]["schemas"][schema_key] = ui_config_schema
|
||||
|
||||
# Add a reference to the output type to additionalProperties of the invoker schema
|
||||
for invoker in all_invocations:
|
||||
@ -140,8 +142,8 @@ def custom_openapi():
|
||||
output_type_title = output_type_titles[output_type.__name__]
|
||||
invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
|
||||
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
|
||||
|
||||
invoker_schema["output"] = outputs_ref
|
||||
invoker_schema["class"] = "invocation"
|
||||
|
||||
from invokeai.backend.model_management.models import get_model_config_enums
|
||||
|
||||
@ -207,6 +209,17 @@ def invoke_api():
|
||||
|
||||
check_invokeai_root(app_config) # note, may exit with an exception if root not set up
|
||||
|
||||
if app_config.dev_reload:
|
||||
try:
|
||||
import jurigged
|
||||
except ImportError as e:
|
||||
logger.error(
|
||||
'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.',
|
||||
exc_info=e,
|
||||
)
|
||||
else:
|
||||
jurigged.watch(logger=InvokeAILogger.getLogger(name="jurigged").info)
|
||||
|
||||
port = find_port(app_config.port)
|
||||
if port != app_config.port:
|
||||
logger.warn(f"Port {app_config.port} in use, using port {port}")
|
||||
|
@ -71,6 +71,9 @@ class FieldDescriptions:
|
||||
safe_mode = "Whether or not to use safe mode"
|
||||
scribble_mode = "Whether or not to use scribble mode"
|
||||
scale_factor = "The factor by which to scale"
|
||||
blend_alpha = (
|
||||
"Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B."
|
||||
)
|
||||
num_1 = "The first number"
|
||||
num_2 = "The second number"
|
||||
mask = "The mask to use for the operation"
|
||||
@ -140,6 +143,7 @@ class UIType(str, Enum):
|
||||
# region Misc
|
||||
FilePath = "FilePath"
|
||||
Enum = "enum"
|
||||
Scheduler = "Scheduler"
|
||||
# endregion
|
||||
|
||||
|
||||
@ -166,6 +170,7 @@ class _InputField(BaseModel):
|
||||
ui_hidden: bool
|
||||
ui_type: Optional[UIType]
|
||||
ui_component: Optional[UIComponent]
|
||||
ui_order: Optional[int]
|
||||
|
||||
|
||||
class _OutputField(BaseModel):
|
||||
@ -178,6 +183,7 @@ class _OutputField(BaseModel):
|
||||
|
||||
ui_hidden: bool
|
||||
ui_type: Optional[UIType]
|
||||
ui_order: Optional[int]
|
||||
|
||||
|
||||
def InputField(
|
||||
@ -211,6 +217,7 @@ def InputField(
|
||||
ui_type: Optional[UIType] = None,
|
||||
ui_component: Optional[UIComponent] = None,
|
||||
ui_hidden: bool = False,
|
||||
ui_order: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""
|
||||
@ -269,6 +276,7 @@ def InputField(
|
||||
ui_type=ui_type,
|
||||
ui_component=ui_component,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -302,6 +310,7 @@ def OutputField(
|
||||
repr: bool = True,
|
||||
ui_type: Optional[UIType] = None,
|
||||
ui_hidden: bool = False,
|
||||
ui_order: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""
|
||||
@ -348,6 +357,7 @@ def OutputField(
|
||||
repr=repr,
|
||||
ui_type=ui_type,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -376,7 +386,7 @@ class BaseInvocationOutput(BaseModel):
|
||||
"""Base class for all invocation outputs"""
|
||||
|
||||
# All outputs must include a type name like this:
|
||||
# type: Literal['your_output_name']
|
||||
# type: Literal['your_output_name'] # noqa f821
|
||||
|
||||
@classmethod
|
||||
def get_all_subclasses_tuple(cls):
|
||||
@ -389,6 +399,13 @@ class BaseInvocationOutput(BaseModel):
|
||||
toprocess.extend(next_subclasses)
|
||||
return tuple(subclasses)
|
||||
|
||||
class Config:
|
||||
@staticmethod
|
||||
def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = list()
|
||||
schema["required"].extend(["type"])
|
||||
|
||||
|
||||
class RequiredConnectionException(Exception):
|
||||
"""Raised when an field which requires a connection did not receive a value."""
|
||||
@ -410,7 +427,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
"""
|
||||
|
||||
# All invocations must include a type name like this:
|
||||
# type: Literal['your_output_name']
|
||||
# type: Literal['your_output_name'] # noqa f821
|
||||
|
||||
@classmethod
|
||||
def get_all_subclasses(cls):
|
||||
@ -449,6 +466,9 @@ class BaseInvocation(ABC, BaseModel):
|
||||
schema["title"] = uiconfig.title
|
||||
if uiconfig and hasattr(uiconfig, "tags"):
|
||||
schema["tags"] = uiconfig.tags
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = list()
|
||||
schema["required"].extend(["type", "id"])
|
||||
|
||||
@abstractmethod
|
||||
def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
|
||||
@ -485,7 +505,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
raise MissingInputException(self.__fields__["type"].default, field_name)
|
||||
return self.invoke(context)
|
||||
|
||||
id: str = InputField(description="The id of this node. Must be unique among all nodes.")
|
||||
id: str = Field(description="The id of this node. Must be unique among all nodes.")
|
||||
is_intermediate: bool = InputField(
|
||||
default=False, description="Whether or not this node is an intermediate node.", input=Input.Direct
|
||||
)
|
||||
|
@ -233,7 +233,7 @@ class SDXLPromptInvocationBase:
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=True, # TODO:
|
||||
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
|
||||
requires_pooled=True,
|
||||
requires_pooled=get_pooled,
|
||||
)
|
||||
|
||||
conjunction = Compel.parse_prompt_string(prompt)
|
||||
|
@ -8,7 +8,7 @@ import numpy
|
||||
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
||||
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
|
||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||
|
||||
@ -41,6 +41,39 @@ class ShowImageInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@title("Blank Image")
|
||||
@tags("image")
|
||||
class BlankImageInvocation(BaseInvocation):
|
||||
"""Creates a blank image and forwards it to the pipeline"""
|
||||
|
||||
# Metadata
|
||||
type: Literal["blank_image"] = "blank_image"
|
||||
|
||||
# Inputs
|
||||
width: int = InputField(default=512, description="The width of the image")
|
||||
height: int = InputField(default=512, description="The height of the image")
|
||||
mode: Literal["RGB", "RGBA"] = InputField(default="RGB", description="The mode of the image")
|
||||
color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color of the image")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = Image.new(mode=self.mode, size=(self.width, self.height), color=self.color.tuple())
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@title("Crop Image")
|
||||
@tags("image", "crop")
|
||||
class ImageCropInvocation(BaseInvocation):
|
||||
|
@ -1,23 +1,25 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
|
||||
import math
|
||||
from typing import Literal, Optional, get_args
|
||||
|
||||
import numpy as np
|
||||
import math
|
||||
from PIL import Image, ImageOps
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput, ColorField
|
||||
|
||||
from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.image_util.lama import LaMA
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, title, tags
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
|
||||
|
||||
|
||||
def infill_methods() -> list[str]:
|
||||
methods = [
|
||||
"tile",
|
||||
"solid",
|
||||
"lama",
|
||||
]
|
||||
if PatchMatch.patchmatch_available():
|
||||
methods.insert(0, "patchmatch")
|
||||
@ -28,6 +30,11 @@ INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||
DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
||||
|
||||
|
||||
def infill_lama(im: Image.Image) -> Image.Image:
|
||||
lama = LaMA()
|
||||
return lama(im)
|
||||
|
||||
|
||||
def infill_patchmatch(im: Image.Image) -> Image.Image:
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
@ -90,7 +97,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
|
||||
return im
|
||||
|
||||
# Find all invalid tiles and replace with a random valid tile
|
||||
replace_count = (tiles_mask is False).sum()
|
||||
replace_count = (tiles_mask == False).sum() # noqa: E712
|
||||
rng = np.random.default_rng(seed=seed)
|
||||
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]
|
||||
|
||||
@ -218,3 +225,34 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@title("LaMa Infill")
|
||||
@tags("image", "inpaint")
|
||||
class LaMaInfillInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image using the LaMa model"""
|
||||
|
||||
type: Literal["infill_lama"] = "infill_lama"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
infilled = infill_lama(image.copy())
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
@ -4,6 +4,7 @@ from contextlib import ExitStack
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
@ -168,22 +169,24 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
# Inputs
|
||||
positive_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
|
||||
)
|
||||
negative_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.negative_cond, input=Input.Connection
|
||||
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
|
||||
)
|
||||
noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection)
|
||||
noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection, ui_order=3)
|
||||
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
||||
cfg_scale: Union[float, List[float]] = InputField(
|
||||
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, ui_type=UIType.Float
|
||||
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, ui_type=UIType.Float, title="CFG Scale"
|
||||
)
|
||||
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
|
||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||
scheduler: SAMPLER_NAME_VALUES = InputField(default="euler", description=FieldDescriptions.scheduler)
|
||||
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection)
|
||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||
default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler
|
||||
)
|
||||
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet", ui_order=2)
|
||||
control: Union[ControlField, list[ControlField]] = InputField(
|
||||
default=None, description=FieldDescriptions.control, input=Input.Connection
|
||||
default=None, description=FieldDescriptions.control, input=Input.Connection, ui_order=5
|
||||
)
|
||||
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
||||
mask: Optional[InpaintMaskField] = InputField(
|
||||
@ -517,7 +520,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
|
||||
@title("Latents to Image")
|
||||
@tags("latents", "image", "vae")
|
||||
@tags("latents", "image", "vae", "l2i")
|
||||
class LatentsToImageInvocation(BaseInvocation):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
@ -705,7 +708,7 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
|
||||
|
||||
@title("Image to Latents")
|
||||
@tags("latents", "image", "vae")
|
||||
@tags("latents", "image", "vae", "i2l")
|
||||
class ImageToLatentsInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
|
||||
@ -786,3 +789,81 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
latents = latents.to("cpu")
|
||||
context.services.latents.save(name, latents)
|
||||
return build_latents_output(latents_name=name, latents=latents, seed=None)
|
||||
|
||||
|
||||
@title("Blend Latents")
|
||||
@tags("latents", "blend")
|
||||
class BlendLatentsInvocation(BaseInvocation):
|
||||
"""Blend two latents using a given alpha. Latents must have same size."""
|
||||
|
||||
type: Literal["lblend"] = "lblend"
|
||||
|
||||
# Inputs
|
||||
latents_a: LatentsField = InputField(
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
latents_b: LatentsField = InputField(
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents_a = context.services.latents.get(self.latents_a.latents_name)
|
||||
latents_b = context.services.latents.get(self.latents_b.latents_name)
|
||||
|
||||
if latents_a.shape != latents_b.shape:
|
||||
raise "Latents to blend must be the same size."
|
||||
|
||||
# TODO:
|
||||
device = choose_torch_device()
|
||||
|
||||
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
|
||||
"""
|
||||
Spherical linear interpolation
|
||||
Args:
|
||||
t (float/np.ndarray): Float value between 0.0 and 1.0
|
||||
v0 (np.ndarray): Starting vector
|
||||
v1 (np.ndarray): Final vector
|
||||
DOT_THRESHOLD (float): Threshold for considering the two vectors as
|
||||
colineal. Not recommended to alter this.
|
||||
Returns:
|
||||
v2 (np.ndarray): Interpolation vector between v0 and v1
|
||||
"""
|
||||
inputs_are_torch = False
|
||||
if not isinstance(v0, np.ndarray):
|
||||
inputs_are_torch = True
|
||||
v0 = v0.detach().cpu().numpy()
|
||||
if not isinstance(v1, np.ndarray):
|
||||
inputs_are_torch = True
|
||||
v1 = v1.detach().cpu().numpy()
|
||||
|
||||
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
|
||||
if np.abs(dot) > DOT_THRESHOLD:
|
||||
v2 = (1 - t) * v0 + t * v1
|
||||
else:
|
||||
theta_0 = np.arccos(dot)
|
||||
sin_theta_0 = np.sin(theta_0)
|
||||
theta_t = theta_0 * t
|
||||
sin_theta_t = np.sin(theta_t)
|
||||
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
|
||||
s1 = sin_theta_t / sin_theta_0
|
||||
v2 = s0 * v0 + s1 * v1
|
||||
|
||||
if inputs_are_torch:
|
||||
v2 = torch.from_numpy(v2).to(device)
|
||||
|
||||
return v2
|
||||
|
||||
# blend
|
||||
blended_latents = slerp(self.alpha, latents_a, latents_b)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
blended_latents = blended_latents.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
# context.services.latents.set(name, resized_latents)
|
||||
context.services.latents.save(name, blended_latents)
|
||||
return build_latents_output(latents_name=name, latents=blended_latents)
|
||||
|
@ -21,7 +21,7 @@ class AddInvocation(BaseInvocation):
|
||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(a=self.a + self.b)
|
||||
return IntegerOutput(value=self.a + self.b)
|
||||
|
||||
|
||||
@title("Subtract Integers")
|
||||
@ -36,7 +36,7 @@ class SubtractInvocation(BaseInvocation):
|
||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(a=self.a - self.b)
|
||||
return IntegerOutput(value=self.a - self.b)
|
||||
|
||||
|
||||
@title("Multiply Integers")
|
||||
@ -51,7 +51,7 @@ class MultiplyInvocation(BaseInvocation):
|
||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(a=self.a * self.b)
|
||||
return IntegerOutput(value=self.a * self.b)
|
||||
|
||||
|
||||
@title("Divide Integers")
|
||||
@ -66,7 +66,7 @@ class DivideInvocation(BaseInvocation):
|
||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(a=int(self.a / self.b))
|
||||
return IntegerOutput(value=int(self.a / self.b))
|
||||
|
||||
|
||||
@title("Random Integer")
|
||||
@ -81,4 +81,4 @@ class RandomIntInvocation(BaseInvocation):
|
||||
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(a=np.random.randint(self.low, self.high))
|
||||
return IntegerOutput(value=np.random.randint(self.low, self.high))
|
||||
|
@ -32,6 +32,7 @@ class CoreMetadata(BaseModelExcludeNull):
|
||||
generation_mode: str = Field(
|
||||
description="The generation mode that output this image",
|
||||
)
|
||||
created_by: Optional[str] = Field(description="The name of the creator of the image")
|
||||
positive_prompt: str = Field(description="The positive prompt parameter")
|
||||
negative_prompt: str = Field(description="The negative prompt parameter")
|
||||
width: int = Field(description="The width parameter")
|
||||
|
@ -72,7 +72,7 @@ class LoRAModelField(BaseModel):
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
|
||||
@title("Main Model Loader")
|
||||
@title("Main Model")
|
||||
@tags("model")
|
||||
class MainModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
@ -179,7 +179,7 @@ class LoraLoaderOutput(BaseInvocationOutput):
|
||||
# fmt: on
|
||||
|
||||
|
||||
@title("LoRA Loader")
|
||||
@title("LoRA")
|
||||
@tags("lora", "model")
|
||||
class LoraLoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
@ -257,7 +257,7 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
||||
# fmt: on
|
||||
|
||||
|
||||
@title("SDXL LoRA Loader")
|
||||
@title("SDXL LoRA")
|
||||
@tags("sdxl", "lora", "model")
|
||||
class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
@ -356,7 +356,7 @@ class VaeLoaderOutput(BaseInvocationOutput):
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@title("VAE Loader")
|
||||
@title("VAE")
|
||||
@tags("vae", "model")
|
||||
class VaeLoaderInvocation(BaseInvocation):
|
||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||
|
@ -169,7 +169,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
ui_type=UIType.Float,
|
||||
)
|
||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct
|
||||
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct, ui_type=UIType.Scheduler
|
||||
)
|
||||
precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision)
|
||||
unet: UNetField = InputField(
|
||||
@ -406,7 +406,7 @@ class OnnxModelField(BaseModel):
|
||||
model_type: ModelType = Field(description="Model Type")
|
||||
|
||||
|
||||
@title("ONNX Model Loader")
|
||||
@title("ONNX Main Model")
|
||||
@tags("onnx", "model")
|
||||
class OnnxModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
@ -33,7 +33,7 @@ class BooleanOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single boolean"""
|
||||
|
||||
type: Literal["boolean_output"] = "boolean_output"
|
||||
a: bool = OutputField(description="The output boolean")
|
||||
value: bool = OutputField(description="The output boolean")
|
||||
|
||||
|
||||
class BooleanCollectionOutput(BaseInvocationOutput):
|
||||
@ -42,9 +42,7 @@ class BooleanCollectionOutput(BaseInvocationOutput):
|
||||
type: Literal["boolean_collection_output"] = "boolean_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[bool] = OutputField(
|
||||
default_factory=list, description="The output boolean collection", ui_type=UIType.BooleanCollection
|
||||
)
|
||||
collection: list[bool] = OutputField(description="The output boolean collection", ui_type=UIType.BooleanCollection)
|
||||
|
||||
|
||||
@title("Boolean Primitive")
|
||||
@ -55,10 +53,10 @@ class BooleanInvocation(BaseInvocation):
|
||||
type: Literal["boolean"] = "boolean"
|
||||
|
||||
# Inputs
|
||||
a: bool = InputField(default=False, description="The boolean value")
|
||||
value: bool = InputField(default=False, description="The boolean value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> BooleanOutput:
|
||||
return BooleanOutput(a=self.a)
|
||||
return BooleanOutput(value=self.value)
|
||||
|
||||
|
||||
@title("Boolean Primitive Collection")
|
||||
@ -70,7 +68,7 @@ class BooleanCollectionInvocation(BaseInvocation):
|
||||
|
||||
# Inputs
|
||||
collection: list[bool] = InputField(
|
||||
default=False, description="The collection of boolean values", ui_type=UIType.BooleanCollection
|
||||
default_factory=list, description="The collection of boolean values", ui_type=UIType.BooleanCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
|
||||
@ -86,7 +84,7 @@ class IntegerOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single integer"""
|
||||
|
||||
type: Literal["integer_output"] = "integer_output"
|
||||
a: int = OutputField(description="The output integer")
|
||||
value: int = OutputField(description="The output integer")
|
||||
|
||||
|
||||
class IntegerCollectionOutput(BaseInvocationOutput):
|
||||
@ -95,9 +93,7 @@ class IntegerCollectionOutput(BaseInvocationOutput):
|
||||
type: Literal["integer_collection_output"] = "integer_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[int] = OutputField(
|
||||
default_factory=list, description="The int collection", ui_type=UIType.IntegerCollection
|
||||
)
|
||||
collection: list[int] = OutputField(description="The int collection", ui_type=UIType.IntegerCollection)
|
||||
|
||||
|
||||
@title("Integer Primitive")
|
||||
@ -108,10 +104,10 @@ class IntegerInvocation(BaseInvocation):
|
||||
type: Literal["integer"] = "integer"
|
||||
|
||||
# Inputs
|
||||
a: int = InputField(default=0, description="The integer value")
|
||||
value: int = InputField(default=0, description="The integer value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(a=self.a)
|
||||
return IntegerOutput(value=self.value)
|
||||
|
||||
|
||||
@title("Integer Primitive Collection")
|
||||
@ -139,7 +135,7 @@ class FloatOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single float"""
|
||||
|
||||
type: Literal["float_output"] = "float_output"
|
||||
a: float = OutputField(description="The output float")
|
||||
value: float = OutputField(description="The output float")
|
||||
|
||||
|
||||
class FloatCollectionOutput(BaseInvocationOutput):
|
||||
@ -148,9 +144,7 @@ class FloatCollectionOutput(BaseInvocationOutput):
|
||||
type: Literal["float_collection_output"] = "float_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[float] = OutputField(
|
||||
default_factory=list, description="The float collection", ui_type=UIType.FloatCollection
|
||||
)
|
||||
collection: list[float] = OutputField(description="The float collection", ui_type=UIType.FloatCollection)
|
||||
|
||||
|
||||
@title("Float Primitive")
|
||||
@ -161,10 +155,10 @@ class FloatInvocation(BaseInvocation):
|
||||
type: Literal["float"] = "float"
|
||||
|
||||
# Inputs
|
||||
param: float = InputField(default=0.0, description="The float value")
|
||||
value: float = InputField(default=0.0, description="The float value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||
return FloatOutput(a=self.param)
|
||||
return FloatOutput(value=self.value)
|
||||
|
||||
|
||||
@title("Float Primitive Collection")
|
||||
@ -176,7 +170,7 @@ class FloatCollectionInvocation(BaseInvocation):
|
||||
|
||||
# Inputs
|
||||
collection: list[float] = InputField(
|
||||
default=0, description="The collection of float values", ui_type=UIType.FloatCollection
|
||||
default_factory=list, description="The collection of float values", ui_type=UIType.FloatCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
@ -192,7 +186,7 @@ class StringOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single string"""
|
||||
|
||||
type: Literal["string_output"] = "string_output"
|
||||
text: str = OutputField(description="The output string")
|
||||
value: str = OutputField(description="The output string")
|
||||
|
||||
|
||||
class StringCollectionOutput(BaseInvocationOutput):
|
||||
@ -201,9 +195,7 @@ class StringCollectionOutput(BaseInvocationOutput):
|
||||
type: Literal["string_collection_output"] = "string_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[str] = OutputField(
|
||||
default_factory=list, description="The output strings", ui_type=UIType.StringCollection
|
||||
)
|
||||
collection: list[str] = OutputField(description="The output strings", ui_type=UIType.StringCollection)
|
||||
|
||||
|
||||
@title("String Primitive")
|
||||
@ -214,10 +206,10 @@ class StringInvocation(BaseInvocation):
|
||||
type: Literal["string"] = "string"
|
||||
|
||||
# Inputs
|
||||
text: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea)
|
||||
value: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||
return StringOutput(text=self.text)
|
||||
return StringOutput(value=self.value)
|
||||
|
||||
|
||||
@title("String Primitive Collection")
|
||||
@ -229,7 +221,7 @@ class StringCollectionInvocation(BaseInvocation):
|
||||
|
||||
# Inputs
|
||||
collection: list[str] = InputField(
|
||||
default=0, description="The collection of string values", ui_type=UIType.StringCollection
|
||||
default_factory=list, description="The collection of string values", ui_type=UIType.StringCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
||||
@ -262,9 +254,7 @@ class ImageCollectionOutput(BaseInvocationOutput):
|
||||
type: Literal["image_collection_output"] = "image_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[ImageField] = OutputField(
|
||||
default_factory=list, description="The output images", ui_type=UIType.ImageCollection
|
||||
)
|
||||
collection: list[ImageField] = OutputField(description="The output images", ui_type=UIType.ImageCollection)
|
||||
|
||||
|
||||
@title("Image Primitive")
|
||||
@ -353,7 +343,6 @@ class LatentsCollectionOutput(BaseInvocationOutput):
|
||||
type: Literal["latents_collection_output"] = "latents_collection_output"
|
||||
|
||||
collection: list[LatentsField] = OutputField(
|
||||
default_factory=list,
|
||||
description=FieldDescriptions.latents,
|
||||
ui_type=UIType.LatentsCollection,
|
||||
)
|
||||
@ -384,7 +373,7 @@ class LatentsCollectionInvocation(BaseInvocation):
|
||||
|
||||
# Inputs
|
||||
collection: list[LatentsField] = InputField(
|
||||
default=0, description="The collection of latents tensors", ui_type=UIType.LatentsCollection
|
||||
description="The collection of latents tensors", ui_type=UIType.LatentsCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsCollectionOutput:
|
||||
@ -429,9 +418,7 @@ class ColorCollectionOutput(BaseInvocationOutput):
|
||||
type: Literal["color_collection_output"] = "color_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[ColorField] = OutputField(
|
||||
default_factory=list, description="The output colors", ui_type=UIType.ColorCollection
|
||||
)
|
||||
collection: list[ColorField] = OutputField(description="The output colors", ui_type=UIType.ColorCollection)
|
||||
|
||||
|
||||
@title("Color Primitive")
|
||||
@ -474,7 +461,6 @@ class ConditioningCollectionOutput(BaseInvocationOutput):
|
||||
|
||||
# Outputs
|
||||
collection: list[ConditioningField] = OutputField(
|
||||
default_factory=list,
|
||||
description="The output conditioning tensors",
|
||||
ui_type=UIType.ConditioningCollection,
|
||||
)
|
||||
|
@ -37,7 +37,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@title("SDXL Main Model Loader")
|
||||
@title("SDXL Main Model")
|
||||
@tags("model", "sdxl")
|
||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl base model, outputting its submodels."""
|
||||
@ -122,7 +122,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@title("SDXL Refiner Model Loader")
|
||||
@title("SDXL Refiner Model")
|
||||
@tags("model", "sdxl", "refiner")
|
||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||
|
8
invokeai/app/services/config/__init__.py
Normal file
8
invokeai/app/services/config/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
"""
|
||||
Init file for InvokeAI configure package
|
||||
"""
|
||||
|
||||
from .invokeai_config import ( # noqa F401
|
||||
InvokeAIAppConfig,
|
||||
get_invokeai_config,
|
||||
)
|
239
invokeai/app/services/config/base.py
Normal file
239
invokeai/app/services/config/base.py
Normal file
@ -0,0 +1,239 @@
|
||||
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team
|
||||
|
||||
"""
|
||||
Base class for the InvokeAI configuration system.
|
||||
It defines a type of pydantic BaseSettings object that
|
||||
is able to read and write from an omegaconf-based config file,
|
||||
with overriding of settings from environment variables and/or
|
||||
the command line.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import os
|
||||
import pydoc
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
from omegaconf import OmegaConf, DictConfig, ListConfig
|
||||
from pathlib import Path
|
||||
from pydantic import BaseSettings
|
||||
from typing import ClassVar, Dict, List, Literal, Union, get_origin, get_type_hints, get_args
|
||||
|
||||
|
||||
class PagingArgumentParser(argparse.ArgumentParser):
|
||||
"""
|
||||
A custom ArgumentParser that uses pydoc to page its output.
|
||||
It also supports reading defaults from an init file.
|
||||
"""
|
||||
|
||||
def print_help(self, file=None):
|
||||
text = self.format_help()
|
||||
pydoc.pager(text)
|
||||
|
||||
|
||||
class InvokeAISettings(BaseSettings):
|
||||
"""
|
||||
Runtime configuration settings in which default values are
|
||||
read from an omegaconf .yaml file.
|
||||
"""
|
||||
|
||||
initconf: ClassVar[DictConfig] = None
|
||||
argparse_groups: ClassVar[Dict] = {}
|
||||
|
||||
def parse_args(self, argv: list = sys.argv[1:]):
|
||||
parser = self.get_parser()
|
||||
opt = parser.parse_args(argv)
|
||||
for name in self.__fields__:
|
||||
if name not in self._excluded():
|
||||
value = getattr(opt, name)
|
||||
if isinstance(value, ListConfig):
|
||||
value = list(value)
|
||||
elif isinstance(value, DictConfig):
|
||||
value = dict(value)
|
||||
setattr(self, name, value)
|
||||
|
||||
def to_yaml(self) -> str:
|
||||
"""
|
||||
Return a YAML string representing our settings. This can be used
|
||||
as the contents of `invokeai.yaml` to restore settings later.
|
||||
"""
|
||||
cls = self.__class__
|
||||
type = get_args(get_type_hints(cls)["type"])[0]
|
||||
field_dict = dict({type: dict()})
|
||||
for name, field in self.__fields__.items():
|
||||
if name in cls._excluded_from_yaml():
|
||||
continue
|
||||
category = field.field_info.extra.get("category") or "Uncategorized"
|
||||
value = getattr(self, name)
|
||||
if category not in field_dict[type]:
|
||||
field_dict[type][category] = dict()
|
||||
# keep paths as strings to make it easier to read
|
||||
field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
|
||||
conf = OmegaConf.create(field_dict)
|
||||
return OmegaConf.to_yaml(conf)
|
||||
|
||||
@classmethod
|
||||
def add_parser_arguments(cls, parser):
|
||||
if "type" in get_type_hints(cls):
|
||||
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
|
||||
else:
|
||||
settings_stanza = "Uncategorized"
|
||||
|
||||
env_prefix = cls.Config.env_prefix if hasattr(cls.Config, "env_prefix") else settings_stanza.upper()
|
||||
|
||||
initconf = (
|
||||
cls.initconf.get(settings_stanza)
|
||||
if cls.initconf and settings_stanza in cls.initconf
|
||||
else OmegaConf.create()
|
||||
)
|
||||
|
||||
# create an upcase version of the environment in
|
||||
# order to achieve case-insensitive environment
|
||||
# variables (the way Windows does)
|
||||
upcase_environ = dict()
|
||||
for key, value in os.environ.items():
|
||||
upcase_environ[key.upper()] = value
|
||||
|
||||
fields = cls.__fields__
|
||||
cls.argparse_groups = {}
|
||||
|
||||
for name, field in fields.items():
|
||||
if name not in cls._excluded():
|
||||
current_default = field.default
|
||||
|
||||
category = field.field_info.extra.get("category", "Uncategorized")
|
||||
env_name = env_prefix + "_" + name
|
||||
if category in initconf and name in initconf.get(category):
|
||||
field.default = initconf.get(category).get(name)
|
||||
if env_name.upper() in upcase_environ:
|
||||
field.default = upcase_environ[env_name.upper()]
|
||||
cls.add_field_argument(parser, name, field)
|
||||
|
||||
field.default = current_default
|
||||
|
||||
@classmethod
|
||||
def cmd_name(self, command_field: str = "type") -> str:
|
||||
hints = get_type_hints(self)
|
||||
if command_field in hints:
|
||||
return get_args(hints[command_field])[0]
|
||||
else:
|
||||
return "Uncategorized"
|
||||
|
||||
@classmethod
|
||||
def get_parser(cls) -> ArgumentParser:
|
||||
parser = PagingArgumentParser(
|
||||
prog=cls.cmd_name(),
|
||||
description=cls.__doc__,
|
||||
)
|
||||
cls.add_parser_arguments(parser)
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
def add_subparser(cls, parser: argparse.ArgumentParser):
|
||||
parser.add_parser(cls.cmd_name(), help=cls.__doc__)
|
||||
|
||||
@classmethod
|
||||
def _excluded(self) -> List[str]:
|
||||
# internal fields that shouldn't be exposed as command line options
|
||||
return ["type", "initconf"]
|
||||
|
||||
@classmethod
|
||||
def _excluded_from_yaml(self) -> List[str]:
|
||||
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
|
||||
return [
|
||||
"type",
|
||||
"initconf",
|
||||
"version",
|
||||
"from_file",
|
||||
"model",
|
||||
"root",
|
||||
"max_cache_size",
|
||||
"max_vram_cache_size",
|
||||
"always_use_cpu",
|
||||
"free_gpu_mem",
|
||||
"xformers_enabled",
|
||||
"tiled_decode",
|
||||
]
|
||||
|
||||
class Config:
|
||||
env_file_encoding = "utf-8"
|
||||
arbitrary_types_allowed = True
|
||||
case_sensitive = True
|
||||
|
||||
@classmethod
|
||||
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
|
||||
field_type = get_type_hints(cls).get(name)
|
||||
default = (
|
||||
default_override
|
||||
if default_override is not None
|
||||
else field.default
|
||||
if field.default_factory is None
|
||||
else field.default_factory()
|
||||
)
|
||||
if category := field.field_info.extra.get("category"):
|
||||
if category not in cls.argparse_groups:
|
||||
cls.argparse_groups[category] = command_parser.add_argument_group(category)
|
||||
argparse_group = cls.argparse_groups[category]
|
||||
else:
|
||||
argparse_group = command_parser
|
||||
|
||||
if get_origin(field_type) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else int_or_float_or_str
|
||||
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field_type,
|
||||
default=default,
|
||||
choices=allowed_values,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
elif get_origin(field_type) == Union:
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=int_or_float_or_str,
|
||||
default=default,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
elif get_origin(field_type) == list:
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
nargs="*",
|
||||
type=field.type_,
|
||||
default=default,
|
||||
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
|
||||
help=field.field_info.description,
|
||||
)
|
||||
else:
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=default,
|
||||
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
|
||||
def int_or_float_or_str(value: str) -> Union[int, float, str]:
|
||||
"""
|
||||
Workaround for argparse type checking.
|
||||
"""
|
||||
try:
|
||||
return int(value)
|
||||
except Exception as e: # noqa F841
|
||||
pass
|
||||
try:
|
||||
return float(value)
|
||||
except Exception as e: # noqa F841
|
||||
pass
|
||||
return str(value)
|
@ -10,37 +10,49 @@ categories returned by `invokeai --help`. The file looks like this:
|
||||
[file: invokeai.yaml]
|
||||
|
||||
InvokeAI:
|
||||
Paths:
|
||||
root: /home/lstein/invokeai-main
|
||||
conf_path: configs/models.yaml
|
||||
legacy_conf_dir: configs/stable-diffusion
|
||||
outdir: outputs
|
||||
autoimport_dir: null
|
||||
Models:
|
||||
model: stable-diffusion-1.5
|
||||
embeddings: true
|
||||
Memory/Performance:
|
||||
xformers_enabled: false
|
||||
sequential_guidance: false
|
||||
precision: float16
|
||||
max_cache_size: 6
|
||||
max_vram_cache_size: 0.5
|
||||
always_use_cpu: false
|
||||
free_gpu_mem: false
|
||||
Features:
|
||||
esrgan: true
|
||||
patchmatch: true
|
||||
internet_available: true
|
||||
log_tokenization: false
|
||||
Web Server:
|
||||
host: 127.0.0.1
|
||||
port: 8081
|
||||
port: 9090
|
||||
allow_origins: []
|
||||
allow_credentials: true
|
||||
allow_methods:
|
||||
- '*'
|
||||
allow_headers:
|
||||
- '*'
|
||||
Features:
|
||||
esrgan: true
|
||||
internet_available: true
|
||||
log_tokenization: false
|
||||
patchmatch: true
|
||||
ignore_missing_core_models: false
|
||||
Paths:
|
||||
autoimport_dir: autoimport
|
||||
lora_dir: null
|
||||
embedding_dir: null
|
||||
controlnet_dir: null
|
||||
conf_path: configs/models.yaml
|
||||
models_dir: models
|
||||
legacy_conf_dir: configs/stable-diffusion
|
||||
db_dir: databases
|
||||
outdir: /home/lstein/invokeai-main/outputs
|
||||
use_memory_db: false
|
||||
Logging:
|
||||
log_handlers:
|
||||
- console
|
||||
log_format: plain
|
||||
log_level: info
|
||||
Model Cache:
|
||||
ram: 13.5
|
||||
vram: 0.25
|
||||
lazy_offload: true
|
||||
Device:
|
||||
device: auto
|
||||
precision: auto
|
||||
Generation:
|
||||
sequential_guidance: false
|
||||
attention_type: xformers
|
||||
attention_slice_size: auto
|
||||
force_tiled_decode: false
|
||||
|
||||
The default name of the configuration file is `invokeai.yaml`, located
|
||||
in INVOKEAI_ROOT. You can replace supersede this by providing any
|
||||
@ -54,24 +66,23 @@ InvokeAIAppConfig.parse_args() will parse the contents of `sys.argv`
|
||||
at initialization time. You may pass a list of strings in the optional
|
||||
`argv` argument to use instead of the system argv:
|
||||
|
||||
conf.parse_args(argv=['--xformers_enabled'])
|
||||
conf.parse_args(argv=['--log_tokenization'])
|
||||
|
||||
It is also possible to set a value at initialization time. However, if
|
||||
you call parse_args() it may be overwritten.
|
||||
|
||||
conf = InvokeAIAppConfig(xformers_enabled=True)
|
||||
conf.parse_args(argv=['--no-xformers'])
|
||||
conf.xformers_enabled
|
||||
conf = InvokeAIAppConfig(log_tokenization=True)
|
||||
conf.parse_args(argv=['--no-log_tokenization'])
|
||||
conf.log_tokenization
|
||||
# False
|
||||
|
||||
|
||||
To avoid this, use `get_config()` to retrieve the application-wide
|
||||
configuration object. This will retain any properties set at object
|
||||
creation time:
|
||||
|
||||
conf = InvokeAIAppConfig.get_config(xformers_enabled=True)
|
||||
conf.parse_args(argv=['--no-xformers'])
|
||||
conf.xformers_enabled
|
||||
conf = InvokeAIAppConfig.get_config(log_tokenization=True)
|
||||
conf.parse_args(argv=['--no-log_tokenization'])
|
||||
conf.log_tokenization
|
||||
# True
|
||||
|
||||
Any setting can be overwritten by setting an environment variable of
|
||||
@ -93,7 +104,7 @@ Typical usage at the top level file:
|
||||
# get global configuration and print its cache size
|
||||
conf = InvokeAIAppConfig.get_config()
|
||||
conf.parse_args()
|
||||
print(conf.max_cache_size)
|
||||
print(conf.ram_cache_size)
|
||||
|
||||
Typical usage in a backend module:
|
||||
|
||||
@ -101,8 +112,7 @@ Typical usage in a backend module:
|
||||
|
||||
# get global configuration and print its cache size value
|
||||
conf = InvokeAIAppConfig.get_config()
|
||||
print(conf.max_cache_size)
|
||||
|
||||
print(conf.ram_cache_size)
|
||||
|
||||
Computed properties:
|
||||
|
||||
@ -159,15 +169,15 @@ two configs are kept in separate sections of the config file:
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import pydoc
|
||||
|
||||
import os
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
from omegaconf import OmegaConf, DictConfig, ListConfig
|
||||
from pathlib import Path
|
||||
from pydantic import BaseSettings, Field, parse_obj_as
|
||||
from typing import ClassVar, Dict, List, Literal, Union, get_origin, get_type_hints, get_args
|
||||
from typing import ClassVar, Dict, List, Literal, Union, get_type_hints, Optional
|
||||
|
||||
from omegaconf import OmegaConf, DictConfig
|
||||
from pydantic import Field, parse_obj_as
|
||||
|
||||
from .base import InvokeAISettings
|
||||
|
||||
INIT_FILE = Path("invokeai.yaml")
|
||||
DB_FILE = Path("invokeai.db")
|
||||
@ -175,195 +185,6 @@ LEGACY_INIT_FILE = Path("invokeai.init")
|
||||
DEFAULT_MAX_VRAM = 0.5
|
||||
|
||||
|
||||
class InvokeAISettings(BaseSettings):
|
||||
"""
|
||||
Runtime configuration settings in which default values are
|
||||
read from an omegaconf .yaml file.
|
||||
"""
|
||||
|
||||
initconf: ClassVar[DictConfig] = None
|
||||
argparse_groups: ClassVar[Dict] = {}
|
||||
|
||||
def parse_args(self, argv: list = sys.argv[1:]):
|
||||
parser = self.get_parser()
|
||||
opt = parser.parse_args(argv)
|
||||
for name in self.__fields__:
|
||||
if name not in self._excluded():
|
||||
value = getattr(opt, name)
|
||||
if isinstance(value, ListConfig):
|
||||
value = list(value)
|
||||
elif isinstance(value, DictConfig):
|
||||
value = dict(value)
|
||||
setattr(self, name, value)
|
||||
|
||||
def to_yaml(self) -> str:
|
||||
"""
|
||||
Return a YAML string representing our settings. This can be used
|
||||
as the contents of `invokeai.yaml` to restore settings later.
|
||||
"""
|
||||
cls = self.__class__
|
||||
type = get_args(get_type_hints(cls)["type"])[0]
|
||||
field_dict = dict({type: dict()})
|
||||
for name, field in self.__fields__.items():
|
||||
if name in cls._excluded_from_yaml():
|
||||
continue
|
||||
category = field.field_info.extra.get("category") or "Uncategorized"
|
||||
value = getattr(self, name)
|
||||
if category not in field_dict[type]:
|
||||
field_dict[type][category] = dict()
|
||||
# keep paths as strings to make it easier to read
|
||||
field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
|
||||
conf = OmegaConf.create(field_dict)
|
||||
return OmegaConf.to_yaml(conf)
|
||||
|
||||
@classmethod
|
||||
def add_parser_arguments(cls, parser):
|
||||
if "type" in get_type_hints(cls):
|
||||
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
|
||||
else:
|
||||
settings_stanza = "Uncategorized"
|
||||
|
||||
env_prefix = cls.Config.env_prefix if hasattr(cls.Config, "env_prefix") else settings_stanza.upper()
|
||||
|
||||
initconf = (
|
||||
cls.initconf.get(settings_stanza)
|
||||
if cls.initconf and settings_stanza in cls.initconf
|
||||
else OmegaConf.create()
|
||||
)
|
||||
|
||||
# create an upcase version of the environment in
|
||||
# order to achieve case-insensitive environment
|
||||
# variables (the way Windows does)
|
||||
upcase_environ = dict()
|
||||
for key, value in os.environ.items():
|
||||
upcase_environ[key.upper()] = value
|
||||
|
||||
fields = cls.__fields__
|
||||
cls.argparse_groups = {}
|
||||
|
||||
for name, field in fields.items():
|
||||
if name not in cls._excluded():
|
||||
current_default = field.default
|
||||
|
||||
category = field.field_info.extra.get("category", "Uncategorized")
|
||||
env_name = env_prefix + "_" + name
|
||||
if category in initconf and name in initconf.get(category):
|
||||
field.default = initconf.get(category).get(name)
|
||||
if env_name.upper() in upcase_environ:
|
||||
field.default = upcase_environ[env_name.upper()]
|
||||
cls.add_field_argument(parser, name, field)
|
||||
|
||||
field.default = current_default
|
||||
|
||||
@classmethod
|
||||
def cmd_name(self, command_field: str = "type") -> str:
|
||||
hints = get_type_hints(self)
|
||||
if command_field in hints:
|
||||
return get_args(hints[command_field])[0]
|
||||
else:
|
||||
return "Uncategorized"
|
||||
|
||||
@classmethod
|
||||
def get_parser(cls) -> ArgumentParser:
|
||||
parser = PagingArgumentParser(
|
||||
prog=cls.cmd_name(),
|
||||
description=cls.__doc__,
|
||||
)
|
||||
cls.add_parser_arguments(parser)
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
def add_subparser(cls, parser: argparse.ArgumentParser):
|
||||
parser.add_parser(cls.cmd_name(), help=cls.__doc__)
|
||||
|
||||
@classmethod
|
||||
def _excluded(self) -> List[str]:
|
||||
# internal fields that shouldn't be exposed as command line options
|
||||
return ["type", "initconf"]
|
||||
|
||||
@classmethod
|
||||
def _excluded_from_yaml(self) -> List[str]:
|
||||
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
|
||||
return [
|
||||
"type",
|
||||
"initconf",
|
||||
"version",
|
||||
"from_file",
|
||||
"model",
|
||||
"root",
|
||||
]
|
||||
|
||||
class Config:
|
||||
env_file_encoding = "utf-8"
|
||||
arbitrary_types_allowed = True
|
||||
case_sensitive = True
|
||||
|
||||
@classmethod
|
||||
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
|
||||
field_type = get_type_hints(cls).get(name)
|
||||
default = (
|
||||
default_override
|
||||
if default_override is not None
|
||||
else field.default
|
||||
if field.default_factory is None
|
||||
else field.default_factory()
|
||||
)
|
||||
if category := field.field_info.extra.get("category"):
|
||||
if category not in cls.argparse_groups:
|
||||
cls.argparse_groups[category] = command_parser.add_argument_group(category)
|
||||
argparse_group = cls.argparse_groups[category]
|
||||
else:
|
||||
argparse_group = command_parser
|
||||
|
||||
if get_origin(field_type) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field_type,
|
||||
default=default,
|
||||
choices=allowed_values,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
elif get_origin(field_type) == list:
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
nargs="*",
|
||||
type=field.type_,
|
||||
default=default,
|
||||
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
|
||||
help=field.field_info.description,
|
||||
)
|
||||
else:
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=default,
|
||||
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
|
||||
def _find_root() -> Path:
|
||||
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
|
||||
if os.environ.get("INVOKEAI_ROOT"):
|
||||
root = Path(os.environ["INVOKEAI_ROOT"])
|
||||
elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
|
||||
root = (venv.parent).resolve()
|
||||
else:
|
||||
root = Path("~/invokeai").expanduser().resolve()
|
||||
return root
|
||||
|
||||
|
||||
class InvokeAIAppConfig(InvokeAISettings):
|
||||
"""
|
||||
Generate images using Stable Diffusion. Use "invokeai" to launch
|
||||
@ -378,6 +199,8 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
|
||||
# fmt: off
|
||||
type: Literal["InvokeAI"] = "InvokeAI"
|
||||
|
||||
# WEB
|
||||
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
|
||||
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
|
||||
allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", category='Web Server')
|
||||
@ -385,20 +208,14 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", category='Web Server')
|
||||
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", category='Web Server')
|
||||
|
||||
# FEATURES
|
||||
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", category='Features')
|
||||
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
|
||||
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
|
||||
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
|
||||
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', category='Features')
|
||||
|
||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
||||
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
||||
max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
|
||||
max_vram_cache_size : float = Field(default=2.75, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
|
||||
precision : Literal['auto', 'float16', 'float32', 'autocast'] = Field(default='auto', description='Floating point precision', category='Memory/Performance')
|
||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
|
||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
||||
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
|
||||
|
||||
# PATHS
|
||||
root : Path = Field(default=None, description='InvokeAI runtime root directory', category='Paths')
|
||||
autoimport_dir : Path = Field(default='autoimport', description='Path to a directory of models files to be imported on startup.', category='Paths')
|
||||
lora_dir : Path = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths')
|
||||
@ -409,16 +226,43 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
|
||||
db_dir : Path = Field(default='databases', description='Path to InvokeAI databases directory', category='Paths')
|
||||
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
|
||||
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
|
||||
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
|
||||
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', category='Features')
|
||||
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
|
||||
|
||||
# LOGGING
|
||||
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
|
||||
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
||||
log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
|
||||
log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
|
||||
|
||||
dev_reload : bool = Field(default=False, description="Automatically reload when Python sources are changed.", category="Development")
|
||||
|
||||
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
|
||||
|
||||
# CACHE
|
||||
ram : Union[float, Literal["auto"]] = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number or 'auto')", category="Model Cache", )
|
||||
vram : Union[float, Literal["auto"]] = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number or 'auto')", category="Model Cache", )
|
||||
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", )
|
||||
|
||||
# DEVICE
|
||||
device : Literal[tuple(["auto", "cpu", "cuda", "cuda:1", "mps"])] = Field(default="auto", description="Generation device", category="Device", )
|
||||
precision: Literal[tuple(["auto", "float16", "float32", "autocast"])] = Field(default="auto", description="Floating point precision", category="Device", )
|
||||
|
||||
# GENERATION
|
||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category="Generation", )
|
||||
attention_type : Literal[tuple(["auto", "normal", "xformers", "sliced", "torch-sdp"])] = Field(default="auto", description="Attention type", category="Generation", )
|
||||
attention_slice_size: Literal[tuple(["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8])] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
|
||||
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
|
||||
|
||||
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
|
||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
||||
free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
||||
max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
|
||||
max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
|
||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
||||
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
|
||||
|
||||
# See InvokeAIAppConfig subclass below for CACHE and DEVICE categories
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
@ -541,11 +385,6 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
"""Return true if precision set to float32"""
|
||||
return self.precision == "float32"
|
||||
|
||||
@property
|
||||
def disable_xformers(self) -> bool:
|
||||
"""Return true if xformers_enabled is false"""
|
||||
return not self.xformers_enabled
|
||||
|
||||
@property
|
||||
def try_patchmatch(self) -> bool:
|
||||
"""Return true if patchmatch true"""
|
||||
@ -561,6 +400,27 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
"""invisible watermark node is always active and disabled from Web UIe"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def ram_cache_size(self) -> float:
|
||||
return self.max_cache_size or self.ram
|
||||
|
||||
@property
|
||||
def vram_cache_size(self) -> float:
|
||||
return self.max_vram_cache_size or self.vram
|
||||
|
||||
@property
|
||||
def use_cpu(self) -> bool:
|
||||
return self.always_use_cpu or self.device == "cpu"
|
||||
|
||||
@property
|
||||
def disable_xformers(self) -> bool:
|
||||
"""
|
||||
Return true if enable_xformers is false (reversed logic)
|
||||
and attention type is not set to xformers.
|
||||
"""
|
||||
disabled_in_config = not self.xformers_enabled
|
||||
return disabled_in_config and self.attention_type != "xformers"
|
||||
|
||||
@staticmethod
|
||||
def find_root() -> Path:
|
||||
"""
|
||||
@ -570,19 +430,19 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
return _find_root()
|
||||
|
||||
|
||||
class PagingArgumentParser(argparse.ArgumentParser):
|
||||
"""
|
||||
A custom ArgumentParser that uses pydoc to page its output.
|
||||
It also supports reading defaults from an init file.
|
||||
"""
|
||||
|
||||
def print_help(self, file=None):
|
||||
text = self.format_help()
|
||||
pydoc.pager(text)
|
||||
|
||||
|
||||
def get_invokeai_config(**kwargs) -> InvokeAIAppConfig:
|
||||
"""
|
||||
Legacy function which returns InvokeAIAppConfig.get_config()
|
||||
"""
|
||||
return InvokeAIAppConfig.get_config(**kwargs)
|
||||
|
||||
|
||||
def _find_root() -> Path:
|
||||
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
|
||||
if os.environ.get("INVOKEAI_ROOT"):
|
||||
root = Path(os.environ["INVOKEAI_ROOT"])
|
||||
elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
|
||||
root = (venv.parent).resolve()
|
||||
else:
|
||||
root = Path("~/invokeai").expanduser().resolve()
|
||||
return root
|
@ -17,9 +17,9 @@ def create_text_to_image() -> LibraryGraph:
|
||||
description="Converts text to an image",
|
||||
graph=Graph(
|
||||
nodes={
|
||||
"width": IntegerInvocation(id="width", a=512),
|
||||
"height": IntegerInvocation(id="height", a=512),
|
||||
"seed": IntegerInvocation(id="seed", a=-1),
|
||||
"width": IntegerInvocation(id="width", value=512),
|
||||
"height": IntegerInvocation(id="height", value=512),
|
||||
"seed": IntegerInvocation(id="seed", value=-1),
|
||||
"3": NoiseInvocation(id="3"),
|
||||
"4": CompelInvocation(id="4"),
|
||||
"5": CompelInvocation(id="5"),
|
||||
@ -29,15 +29,15 @@ def create_text_to_image() -> LibraryGraph:
|
||||
},
|
||||
edges=[
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="width", field="a"),
|
||||
source=EdgeConnection(node_id="width", field="value"),
|
||||
destination=EdgeConnection(node_id="3", field="width"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="height", field="a"),
|
||||
source=EdgeConnection(node_id="height", field="value"),
|
||||
destination=EdgeConnection(node_id="3", field="height"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="seed", field="a"),
|
||||
source=EdgeConnection(node_id="seed", field="value"),
|
||||
destination=EdgeConnection(node_id="3", field="seed"),
|
||||
),
|
||||
Edge(
|
||||
@ -65,9 +65,9 @@ def create_text_to_image() -> LibraryGraph:
|
||||
exposed_inputs=[
|
||||
ExposedNodeInput(node_path="4", field="prompt", alias="positive_prompt"),
|
||||
ExposedNodeInput(node_path="5", field="prompt", alias="negative_prompt"),
|
||||
ExposedNodeInput(node_path="width", field="a", alias="width"),
|
||||
ExposedNodeInput(node_path="height", field="a", alias="height"),
|
||||
ExposedNodeInput(node_path="seed", field="a", alias="seed"),
|
||||
ExposedNodeInput(node_path="width", field="value", alias="width"),
|
||||
ExposedNodeInput(node_path="height", field="value", alias="height"),
|
||||
ExposedNodeInput(node_path="seed", field="value", alias="seed"),
|
||||
],
|
||||
exposed_outputs=[ExposedNodeOutput(node_path="8", field="image", alias="image")],
|
||||
)
|
||||
|
@ -49,9 +49,36 @@ from invokeai.backend.model_management.model_cache import CacheStats
|
||||
GIG = 1073741824
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeStats:
|
||||
"""Class for tracking execution stats of an invocation node"""
|
||||
|
||||
calls: int = 0
|
||||
time_used: float = 0.0 # seconds
|
||||
max_vram: float = 0.0 # GB
|
||||
cache_hits: int = 0
|
||||
cache_misses: int = 0
|
||||
cache_high_watermark: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeLog:
|
||||
"""Class for tracking node usage"""
|
||||
|
||||
# {node_type => NodeStats}
|
||||
nodes: Dict[str, NodeStats] = field(default_factory=dict)
|
||||
|
||||
|
||||
class InvocationStatsServiceBase(ABC):
|
||||
"Abstract base class for recording node memory/time performance statistics"
|
||||
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||
# {graph_id => NodeLog}
|
||||
_stats: Dict[str, NodeLog]
|
||||
_cache_stats: Dict[str, CacheStats]
|
||||
ram_used: float
|
||||
ram_changed: float
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
|
||||
"""
|
||||
@ -94,8 +121,6 @@ class InvocationStatsServiceBase(ABC):
|
||||
invocation_type: str,
|
||||
time_used: float,
|
||||
vram_used: float,
|
||||
ram_used: float,
|
||||
ram_changed: float,
|
||||
):
|
||||
"""
|
||||
Add timing information on execution of a node. Usually
|
||||
@ -104,8 +129,6 @@ class InvocationStatsServiceBase(ABC):
|
||||
:param invocation_type: String literal type of the node
|
||||
:param time_used: Time used by node's exection (sec)
|
||||
:param vram_used: Maximum VRAM used during exection (GB)
|
||||
:param ram_used: Current RAM available (GB)
|
||||
:param ram_changed: Change in RAM usage over course of the run (GB)
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -116,25 +139,19 @@ class InvocationStatsServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_mem_stats(
|
||||
self,
|
||||
ram_used: float,
|
||||
ram_changed: float,
|
||||
):
|
||||
"""
|
||||
Update the collector with RAM memory usage info.
|
||||
|
||||
@dataclass
|
||||
class NodeStats:
|
||||
"""Class for tracking execution stats of an invocation node"""
|
||||
|
||||
calls: int = 0
|
||||
time_used: float = 0.0 # seconds
|
||||
max_vram: float = 0.0 # GB
|
||||
cache_hits: int = 0
|
||||
cache_misses: int = 0
|
||||
cache_high_watermark: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeLog:
|
||||
"""Class for tracking node usage"""
|
||||
|
||||
# {node_type => NodeStats}
|
||||
nodes: Dict[str, NodeStats] = field(default_factory=dict)
|
||||
:param ram_used: How much RAM is currently in use.
|
||||
:param ram_changed: How much RAM changed since last generation.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class InvocationStatsService(InvocationStatsServiceBase):
|
||||
@ -152,12 +169,12 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
class StatsContext:
|
||||
"""Context manager for collecting statistics."""
|
||||
|
||||
invocation: BaseInvocation = None
|
||||
collector: "InvocationStatsServiceBase" = None
|
||||
graph_id: str = None
|
||||
start_time: int = 0
|
||||
ram_used: int = 0
|
||||
model_manager: ModelManagerService = None
|
||||
invocation: BaseInvocation
|
||||
collector: "InvocationStatsServiceBase"
|
||||
graph_id: str
|
||||
start_time: float
|
||||
ram_used: int
|
||||
model_manager: ModelManagerService
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -170,7 +187,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
self.invocation = invocation
|
||||
self.collector = collector
|
||||
self.graph_id = graph_id
|
||||
self.start_time = 0
|
||||
self.start_time = 0.0
|
||||
self.ram_used = 0
|
||||
self.model_manager = model_manager
|
||||
|
||||
@ -191,7 +208,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
)
|
||||
self.collector.update_invocation_stats(
|
||||
graph_id=self.graph_id,
|
||||
invocation_type=self.invocation.type,
|
||||
invocation_type=self.invocation.type, # type: ignore - `type` is not on the `BaseInvocation` model, but *is* on all invocations
|
||||
time_used=time.time() - self.start_time,
|
||||
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
|
||||
)
|
||||
@ -202,11 +219,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
graph_execution_state_id: str,
|
||||
model_manager: ModelManagerService,
|
||||
) -> StatsContext:
|
||||
"""
|
||||
Return a context object that will capture the statistics.
|
||||
:param invocation: BaseInvocation object from the current graph.
|
||||
:param graph_execution_state: GraphExecutionState object from the current session.
|
||||
"""
|
||||
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
|
||||
self._stats[graph_execution_state_id] = NodeLog()
|
||||
self._cache_stats[graph_execution_state_id] = CacheStats()
|
||||
@ -217,7 +229,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
self._stats = {}
|
||||
|
||||
def reset_stats(self, graph_execution_id: str):
|
||||
"""Zero the statistics for the indicated graph."""
|
||||
try:
|
||||
self._stats.pop(graph_execution_id)
|
||||
except KeyError:
|
||||
@ -228,12 +239,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
ram_used: float,
|
||||
ram_changed: float,
|
||||
):
|
||||
"""
|
||||
Update the collector with RAM memory usage info.
|
||||
|
||||
:param ram_used: How much RAM is currently in use.
|
||||
:param ram_changed: How much RAM changed since last generation.
|
||||
"""
|
||||
self.ram_used = ram_used
|
||||
self.ram_changed = ram_changed
|
||||
|
||||
@ -244,16 +249,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
time_used: float,
|
||||
vram_used: float,
|
||||
):
|
||||
"""
|
||||
Add timing information on execution of a node. Usually
|
||||
used internally.
|
||||
:param graph_id: ID of the graph that is currently executing
|
||||
:param invocation_type: String literal type of the node
|
||||
:param time_used: Time used by node's exection (sec)
|
||||
:param vram_used: Maximum VRAM used during exection (GB)
|
||||
:param ram_used: Current RAM available (GB)
|
||||
:param ram_changed: Change in RAM usage over course of the run (GB)
|
||||
"""
|
||||
if not self._stats[graph_id].nodes.get(invocation_type):
|
||||
self._stats[graph_id].nodes[invocation_type] = NodeStats()
|
||||
stats = self._stats[graph_id].nodes[invocation_type]
|
||||
@ -262,14 +257,15 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
stats.max_vram = max(stats.max_vram, vram_used)
|
||||
|
||||
def log_stats(self):
|
||||
"""
|
||||
Send the statistics to the system logger at the info level.
|
||||
Stats will only be printed when the execution of the graph
|
||||
is complete.
|
||||
"""
|
||||
completed = set()
|
||||
errored = set()
|
||||
for graph_id, node_log in self._stats.items():
|
||||
current_graph_state = self.graph_execution_manager.get(graph_id)
|
||||
try:
|
||||
current_graph_state = self.graph_execution_manager.get(graph_id)
|
||||
except Exception:
|
||||
errored.add(graph_id)
|
||||
continue
|
||||
|
||||
if not current_graph_state.is_complete():
|
||||
continue
|
||||
|
||||
@ -302,3 +298,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
for graph_id in completed:
|
||||
del self._stats[graph_id]
|
||||
del self._cache_stats[graph_id]
|
||||
|
||||
for graph_id in errored:
|
||||
del self._stats[graph_id]
|
||||
del self._cache_stats[graph_id]
|
||||
|
@ -330,8 +330,8 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
# configuration value. If present, then the
|
||||
# cache size is set to 2.5 GB times
|
||||
# the number of max_loaded_models. Otherwise
|
||||
# use new `max_cache_size` config setting
|
||||
max_cache_size = config.max_cache_size if hasattr(config, "max_cache_size") else config.max_loaded_models * 2.5
|
||||
# use new `ram_cache_size` config setting
|
||||
max_cache_size = config.ram_cache_size
|
||||
|
||||
logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")
|
||||
|
||||
|
56
invokeai/backend/image_util/lama.py
Normal file
56
invokeai/backend/image_util/lama.py
Normal file
@ -0,0 +1,56 @@
|
||||
import gc
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
|
||||
|
||||
def norm_img(np_img):
|
||||
if len(np_img.shape) == 2:
|
||||
np_img = np_img[:, :, np.newaxis]
|
||||
np_img = np.transpose(np_img, (2, 0, 1))
|
||||
np_img = np_img.astype("float32") / 255
|
||||
return np_img
|
||||
|
||||
|
||||
def load_jit_model(url_or_path, device):
|
||||
model_path = url_or_path
|
||||
print(f"Loading model from: {model_path}")
|
||||
model = torch.jit.load(model_path, map_location="cpu").to(device)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
class LaMA:
|
||||
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
|
||||
device = choose_torch_device()
|
||||
model_location = get_invokeai_config().models_path / "core/misc/lama/lama.pt"
|
||||
model = load_jit_model(model_location, device)
|
||||
|
||||
image = np.asarray(input_image.convert("RGB"))
|
||||
image = norm_img(image)
|
||||
|
||||
mask = input_image.split()[-1]
|
||||
mask = np.asarray(mask)
|
||||
mask = np.invert(mask)
|
||||
mask = norm_img(mask)
|
||||
|
||||
mask = (mask > 0) * 1
|
||||
image = torch.from_numpy(image).unsqueeze(0).to(device)
|
||||
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
|
||||
|
||||
with torch.inference_mode():
|
||||
infilled_image = model(image, mask)
|
||||
|
||||
infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy()
|
||||
infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8")
|
||||
infilled_image = Image.fromarray(infilled_image)
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
|
||||
return infilled_image
|
@ -21,6 +21,7 @@ from argparse import Namespace
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from shutil import get_terminal_size
|
||||
from typing import get_type_hints, get_args, Any
|
||||
from urllib import request
|
||||
|
||||
import npyscreen
|
||||
@ -49,7 +50,8 @@ from invokeai.frontend.install.model_install import addModelsForm, process_and_e
|
||||
|
||||
# TO DO - Move all the frontend code into invokeai.frontend.install
|
||||
from invokeai.frontend.install.widgets import (
|
||||
SingleSelectColumns,
|
||||
SingleSelectColumnsSimple,
|
||||
MultiSelectColumns,
|
||||
CenteredButtonPress,
|
||||
FileBox,
|
||||
set_min_terminal_size,
|
||||
@ -71,6 +73,10 @@ warnings.filterwarnings("ignore")
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
|
||||
def get_literal_fields(field) -> list[Any]:
|
||||
return get_args(get_type_hints(InvokeAIAppConfig).get(field))
|
||||
|
||||
|
||||
# --------------------------globals-----------------------
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
@ -80,7 +86,11 @@ Model_dir = "models"
|
||||
Default_config_file = config.model_conf_path
|
||||
SD_Configs = config.legacy_conf_path
|
||||
|
||||
PRECISION_CHOICES = ["auto", "float16", "float32"]
|
||||
PRECISION_CHOICES = get_literal_fields("precision")
|
||||
DEVICE_CHOICES = get_literal_fields("device")
|
||||
ATTENTION_CHOICES = get_literal_fields("attention_type")
|
||||
ATTENTION_SLICE_CHOICES = get_literal_fields("attention_slice_size")
|
||||
GENERATION_OPT_CHOICES = ["sequential_guidance", "force_tiled_decode", "lazy_offload"]
|
||||
GB = 1073741824 # GB in bytes
|
||||
HAS_CUDA = torch.cuda.is_available()
|
||||
_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0, 0)
|
||||
@ -311,6 +321,7 @@ class editOptsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.
|
||||
Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
"""
|
||||
self.nextrely -= 1
|
||||
for i in textwrap.wrap(label, width=window_width - 6):
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
@ -337,76 +348,127 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
use_two_lines=False,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="GPU Management",
|
||||
begin_entry_at=0,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.free_gpu_mem = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Free GPU memory after each generation",
|
||||
value=old_opts.free_gpu_mem,
|
||||
max_width=45,
|
||||
relx=5,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.xformers_enabled = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Enable xformers support",
|
||||
value=old_opts.xformers_enabled,
|
||||
max_width=30,
|
||||
relx=50,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.always_use_cpu = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Force CPU to be used on GPU systems",
|
||||
value=old_opts.always_use_cpu,
|
||||
relx=80,
|
||||
scroll_exit=True,
|
||||
)
|
||||
|
||||
# old settings for defaults
|
||||
precision = old_opts.precision or ("float32" if program_opts.full_precision else "auto")
|
||||
device = old_opts.device
|
||||
attention_type = old_opts.attention_type
|
||||
attention_slice_size = old_opts.attention_slice_size
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="Floating Point Precision",
|
||||
name="Image Generation Options:",
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 2
|
||||
self.generation_options = self.add_widget_intelligent(
|
||||
MultiSelectColumns,
|
||||
columns=3,
|
||||
values=GENERATION_OPT_CHOICES,
|
||||
value=[GENERATION_OPT_CHOICES.index(x) for x in GENERATION_OPT_CHOICES if getattr(old_opts, x)],
|
||||
relx=30,
|
||||
max_height=2,
|
||||
max_width=80,
|
||||
scroll_exit=True,
|
||||
)
|
||||
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="Floating Point Precision:",
|
||||
begin_entry_at=0,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.nextrely -= 2
|
||||
self.precision = self.add_widget_intelligent(
|
||||
SingleSelectColumns,
|
||||
columns=3,
|
||||
SingleSelectColumnsSimple,
|
||||
columns=len(PRECISION_CHOICES),
|
||||
name="Precision",
|
||||
values=PRECISION_CHOICES,
|
||||
value=PRECISION_CHOICES.index(precision),
|
||||
begin_entry_at=3,
|
||||
max_height=2,
|
||||
relx=30,
|
||||
max_width=56,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="Generation Device:",
|
||||
begin_entry_at=0,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 2
|
||||
self.device = self.add_widget_intelligent(
|
||||
SingleSelectColumnsSimple,
|
||||
columns=len(DEVICE_CHOICES),
|
||||
values=DEVICE_CHOICES,
|
||||
value=[DEVICE_CHOICES.index(device)],
|
||||
begin_entry_at=3,
|
||||
relx=30,
|
||||
max_height=2,
|
||||
max_width=60,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="Attention Type:",
|
||||
begin_entry_at=0,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 2
|
||||
self.attention_type = self.add_widget_intelligent(
|
||||
SingleSelectColumnsSimple,
|
||||
columns=len(ATTENTION_CHOICES),
|
||||
values=ATTENTION_CHOICES,
|
||||
value=[ATTENTION_CHOICES.index(attention_type)],
|
||||
begin_entry_at=3,
|
||||
max_height=2,
|
||||
relx=30,
|
||||
max_width=80,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.attention_type.on_changed = self.show_hide_slice_sizes
|
||||
self.attention_slice_label = self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="Attention Slice Size:",
|
||||
relx=5,
|
||||
editable=False,
|
||||
hidden=attention_type != "sliced",
|
||||
color="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 2
|
||||
self.attention_slice_size = self.add_widget_intelligent(
|
||||
SingleSelectColumnsSimple,
|
||||
columns=len(ATTENTION_SLICE_CHOICES),
|
||||
values=ATTENTION_SLICE_CHOICES,
|
||||
value=[ATTENTION_SLICE_CHOICES.index(attention_slice_size)],
|
||||
relx=30,
|
||||
hidden=attention_type != "sliced",
|
||||
max_height=2,
|
||||
max_width=110,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="RAM cache size (GB). Make this at least large enough to hold a single full model.",
|
||||
name="Model RAM cache size (GB). Make this at least large enough to hold a single full model.",
|
||||
begin_entry_at=0,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.max_cache_size = self.add_widget_intelligent(
|
||||
self.ram = self.add_widget_intelligent(
|
||||
npyscreen.Slider,
|
||||
value=clip(old_opts.max_cache_size, range=(3.0, MAX_RAM), step=0.5),
|
||||
value=clip(old_opts.ram_cache_size, range=(3.0, MAX_RAM), step=0.5),
|
||||
out_of=round(MAX_RAM),
|
||||
lowest=0.0,
|
||||
step=0.5,
|
||||
@ -417,16 +479,16 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="VRAM cache size (GB). Reserving a small amount of VRAM will modestly speed up the start of image generation.",
|
||||
name="Model VRAM cache size (GB). Reserving a small amount of VRAM will modestly speed up the start of image generation.",
|
||||
begin_entry_at=0,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.max_vram_cache_size = self.add_widget_intelligent(
|
||||
self.vram = self.add_widget_intelligent(
|
||||
npyscreen.Slider,
|
||||
value=clip(old_opts.max_vram_cache_size, range=(0, MAX_VRAM), step=0.25),
|
||||
value=clip(old_opts.vram_cache_size, range=(0, MAX_VRAM), step=0.25),
|
||||
out_of=round(MAX_VRAM * 2) / 2,
|
||||
lowest=0.0,
|
||||
relx=8,
|
||||
@ -434,7 +496,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
scroll_exit=True,
|
||||
)
|
||||
else:
|
||||
self.max_vram_cache_size = DummyWidgetValue.zero
|
||||
self.vram_cache_size = DummyWidgetValue.zero
|
||||
self.nextrely += 1
|
||||
self.outdir = self.add_widget_intelligent(
|
||||
FileBox,
|
||||
@ -490,6 +552,11 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
||||
when_pressed_function=self.on_ok,
|
||||
)
|
||||
|
||||
def show_hide_slice_sizes(self, value):
|
||||
show = ATTENTION_CHOICES[value[0]] == "sliced"
|
||||
self.attention_slice_label.hidden = not show
|
||||
self.attention_slice_size.hidden = not show
|
||||
|
||||
def on_ok(self):
|
||||
options = self.marshall_arguments()
|
||||
if self.validate_field_values(options):
|
||||
@ -523,12 +590,9 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
||||
new_opts = Namespace()
|
||||
|
||||
for attr in [
|
||||
"ram",
|
||||
"vram",
|
||||
"outdir",
|
||||
"free_gpu_mem",
|
||||
"max_cache_size",
|
||||
"max_vram_cache_size",
|
||||
"xformers_enabled",
|
||||
"always_use_cpu",
|
||||
]:
|
||||
setattr(new_opts, attr, getattr(self, attr).value)
|
||||
|
||||
@ -541,6 +605,12 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
||||
new_opts.hf_token = self.hf_token.value
|
||||
new_opts.license_acceptance = self.license_acceptance.value
|
||||
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
|
||||
new_opts.device = DEVICE_CHOICES[self.device.value[0]]
|
||||
new_opts.attention_type = ATTENTION_CHOICES[self.attention_type.value[0]]
|
||||
new_opts.attention_slice_size = ATTENTION_SLICE_CHOICES[self.attention_slice_size.value[0]]
|
||||
generation_options = [GENERATION_OPT_CHOICES[x] for x in self.generation_options.value]
|
||||
for v in GENERATION_OPT_CHOICES:
|
||||
setattr(new_opts, v, v in generation_options)
|
||||
|
||||
return new_opts
|
||||
|
||||
@ -635,8 +705,6 @@ def initialize_rootdir(root: Path, yes_to_all: bool = False):
|
||||
path = dest / "core"
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
maybe_create_models_yaml(root)
|
||||
|
||||
|
||||
def maybe_create_models_yaml(root: Path):
|
||||
models_yaml = root / "configs" / "models.yaml"
|
||||
|
@ -20,11 +20,36 @@
|
||||
import re
|
||||
from contextlib import nullcontext
|
||||
from io import BytesIO
|
||||
from typing import Optional, Union
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from diffusers.models import (
|
||||
AutoencoderKL,
|
||||
ControlNetModel,
|
||||
PriorTransformer,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
||||
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
|
||||
from diffusers.schedulers import (
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
UnCLIPScheduler,
|
||||
)
|
||||
from diffusers.utils import is_accelerate_available, is_omegaconf_available
|
||||
from diffusers.utils.import_utils import BACKENDS_MAPPING
|
||||
from picklescan.scanner import scan_file_path
|
||||
from transformers import (
|
||||
AutoFeatureExtractor,
|
||||
BertTokenizerFast,
|
||||
@ -37,35 +62,8 @@ from transformers import (
|
||||
CLIPVisionModelWithProjection,
|
||||
)
|
||||
|
||||
from diffusers.models import (
|
||||
AutoencoderKL,
|
||||
ControlNetModel,
|
||||
PriorTransformer,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.schedulers import (
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
UnCLIPScheduler,
|
||||
)
|
||||
from diffusers.utils import is_accelerate_available, is_omegaconf_available, is_safetensors_available
|
||||
from diffusers.utils.import_utils import BACKENDS_MAPPING
|
||||
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
||||
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
|
||||
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
from picklescan.scanner import scan_file_path
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from .models import BaseModelType, ModelVariantType
|
||||
|
||||
try:
|
||||
@ -1221,9 +1219,6 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
|
||||
|
||||
if from_safetensors:
|
||||
if not is_safetensors_available():
|
||||
raise ValueError(BACKENDS_MAPPING["safetensors"][1])
|
||||
|
||||
from safetensors.torch import load_file as safe_load
|
||||
|
||||
checkpoint = safe_load(checkpoint_path, device="cpu")
|
||||
@ -1662,9 +1657,6 @@ def download_controlnet_from_original_ckpt(
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
if from_safetensors:
|
||||
if not is_safetensors_available():
|
||||
raise ValueError(BACKENDS_MAPPING["safetensors"][1])
|
||||
|
||||
from safetensors import safe_open
|
||||
|
||||
checkpoint = {}
|
||||
@ -1741,7 +1733,7 @@ def convert_ckpt_to_diffusers(
|
||||
|
||||
pipe.save_pretrained(
|
||||
dump_path,
|
||||
safe_serialization=use_safetensors and is_safetensors_available(),
|
||||
safe_serialization=use_safetensors,
|
||||
)
|
||||
|
||||
|
||||
@ -1757,7 +1749,4 @@ def convert_controlnet_to_diffusers(
|
||||
"""
|
||||
pipe = download_controlnet_from_original_ckpt(checkpoint_path, **kwargs)
|
||||
|
||||
pipe.save_pretrained(
|
||||
dump_path,
|
||||
safe_serialization=is_safetensors_available(),
|
||||
)
|
||||
pipe.save_pretrained(dump_path, safe_serialization=True)
|
||||
|
@ -341,7 +341,8 @@ class ModelManager(object):
|
||||
self.logger = logger
|
||||
self.cache = ModelCache(
|
||||
max_cache_size=max_cache_size,
|
||||
max_vram_cache_size=self.app_config.max_vram_cache_size,
|
||||
max_vram_cache_size=self.app_config.vram_cache_size,
|
||||
lazy_offloading=self.app_config.lazy_offload,
|
||||
execution_device=device_type,
|
||||
precision=precision,
|
||||
sequential_offload=sequential_offload,
|
||||
|
@ -5,7 +5,6 @@ from typing import Optional
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
from diffusers.utils import is_safetensors_available
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
@ -175,5 +174,5 @@ def _convert_vae_ckpt_and_cache(
|
||||
vae_config=config,
|
||||
image_size=image_size,
|
||||
)
|
||||
vae_model.save_pretrained(output_path, safe_serialization=is_safetensors_available())
|
||||
vae_model.save_pretrained(output_path, safe_serialization=True)
|
||||
return output_path
|
||||
|
@ -33,7 +33,7 @@ from .diffusion import (
|
||||
PostprocessingSettings,
|
||||
BasicConditioningInfo,
|
||||
)
|
||||
from ..util import normalize_device
|
||||
from ..util import normalize_device, auto_detect_slice_size
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -291,6 +291,24 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
if xformers is available, use it, otherwise use sliced attention.
|
||||
"""
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
if config.attention_type == "xformers":
|
||||
self.enable_xformers_memory_efficient_attention()
|
||||
return
|
||||
elif config.attention_type == "sliced":
|
||||
slice_size = config.attention_slice_size
|
||||
if slice_size == "auto":
|
||||
slice_size = auto_detect_slice_size(latents)
|
||||
elif slice_size == "balanced":
|
||||
slice_size = "auto"
|
||||
self.enable_attention_slicing(slice_size=slice_size)
|
||||
return
|
||||
elif config.attention_type == "normal":
|
||||
self.disable_attention_slicing()
|
||||
return
|
||||
elif config.attention_type == "torch-sdp":
|
||||
raise Exception("torch-sdp attention slicing not yet implemented")
|
||||
|
||||
# the remainder if this code is called when attention_type=='auto'
|
||||
if self.unet.device.type == "cuda":
|
||||
if is_xformers_available() and not config.disable_xformers:
|
||||
self.enable_xformers_memory_efficient_attention()
|
||||
|
@ -11,4 +11,11 @@ from .devices import ( # noqa: F401
|
||||
torch_dtype,
|
||||
)
|
||||
from .log import write_log # noqa: F401
|
||||
from .util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name, Chdir # noqa: F401
|
||||
from .util import ( # noqa: F401
|
||||
ask_user,
|
||||
download_with_resume,
|
||||
instantiate_from_config,
|
||||
url_attachment_name,
|
||||
Chdir,
|
||||
)
|
||||
from .attention import auto_detect_slice_size # noqa: F401
|
||||
|
32
invokeai/backend/util/attention.py
Normal file
32
invokeai/backend/util/attention.py
Normal file
@ -0,0 +1,32 @@
|
||||
# Copyright (c) 2023 Lincoln Stein and the InvokeAI Team
|
||||
"""
|
||||
Utility routine used for autodetection of optimal slice size
|
||||
for attention mechanism.
|
||||
"""
|
||||
import torch
|
||||
import psutil
|
||||
|
||||
|
||||
def auto_detect_slice_size(latents: torch.Tensor) -> str:
|
||||
bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
|
||||
max_size_required_for_baddbmm = (
|
||||
16
|
||||
* latents.size(dim=2)
|
||||
* latents.size(dim=3)
|
||||
* latents.size(dim=2)
|
||||
* latents.size(dim=3)
|
||||
* bytes_per_element_needed_for_baddbmm_duplication
|
||||
)
|
||||
if latents.device.type in {"cpu", "mps"}:
|
||||
mem_free = psutil.virtual_memory().free
|
||||
elif latents.device.type == "cuda":
|
||||
mem_free, _ = torch.cuda.mem_get_info(latents.device)
|
||||
else:
|
||||
raise ValueError(f"unrecognized device {latents.device}")
|
||||
|
||||
if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0):
|
||||
return "max"
|
||||
elif torch.backends.mps.is_available():
|
||||
return "max"
|
||||
else:
|
||||
return "balanced"
|
@ -17,13 +17,17 @@ config = InvokeAIAppConfig.get_config()
|
||||
|
||||
def choose_torch_device() -> torch.device:
|
||||
"""Convenience routine for guessing which GPU device to run model on"""
|
||||
if config.always_use_cpu:
|
||||
if config.use_cpu: # legacy setting - force CPU
|
||||
return CPU_DEVICE
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
return torch.device("mps")
|
||||
return CPU_DEVICE
|
||||
elif config.device == "auto":
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
return torch.device("mps")
|
||||
else:
|
||||
return CPU_DEVICE
|
||||
else:
|
||||
return torch.device(config.device)
|
||||
|
||||
|
||||
def choose_precision(device: torch.device) -> str:
|
||||
|
@ -17,8 +17,8 @@ from shutil import get_terminal_size
|
||||
from curses import BUTTON2_CLICKED, BUTTON3_CLICKED
|
||||
|
||||
# minimum size for UIs
|
||||
MIN_COLS = 130
|
||||
MIN_LINES = 38
|
||||
MIN_COLS = 150
|
||||
MIN_LINES = 40
|
||||
|
||||
|
||||
class WindowTooSmallException(Exception):
|
||||
@ -177,6 +177,8 @@ class FloatTitleSlider(npyscreen.TitleText):
|
||||
|
||||
|
||||
class SelectColumnBase:
|
||||
"""Base class for selection widget arranged in columns."""
|
||||
|
||||
def make_contained_widgets(self):
|
||||
self._my_widgets = []
|
||||
column_width = self.width // self.columns
|
||||
@ -253,6 +255,7 @@ class MultiSelectColumns(SelectColumnBase, npyscreen.MultiSelect):
|
||||
class SingleSelectWithChanged(npyscreen.SelectOne):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.on_changed = None
|
||||
|
||||
def h_select(self, ch):
|
||||
super().h_select(ch)
|
||||
@ -260,7 +263,9 @@ class SingleSelectWithChanged(npyscreen.SelectOne):
|
||||
self.on_changed(self.value)
|
||||
|
||||
|
||||
class SingleSelectColumns(SelectColumnBase, SingleSelectWithChanged):
|
||||
class SingleSelectColumnsSimple(SelectColumnBase, SingleSelectWithChanged):
|
||||
"""Row of radio buttons. Spacebar to select."""
|
||||
|
||||
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
|
||||
self.columns = columns
|
||||
self.value_cnt = len(values)
|
||||
@ -268,15 +273,19 @@ class SingleSelectColumns(SelectColumnBase, SingleSelectWithChanged):
|
||||
self.on_changed = None
|
||||
super().__init__(screen, values=values, **keywords)
|
||||
|
||||
def when_value_edited(self):
|
||||
self.h_select(self.cursor_line)
|
||||
def h_cursor_line_right(self, ch):
|
||||
self.h_exit_down("bye bye")
|
||||
|
||||
def h_cursor_line_left(self, ch):
|
||||
self.h_exit_up("bye bye")
|
||||
|
||||
|
||||
class SingleSelectColumns(SingleSelectColumnsSimple):
|
||||
"""Row of radio buttons. When tabbing over a selection, it is auto selected."""
|
||||
|
||||
def when_cursor_moved(self):
|
||||
self.h_select(self.cursor_line)
|
||||
|
||||
def h_cursor_line_right(self, ch):
|
||||
self.h_exit_down("bye bye")
|
||||
|
||||
|
||||
class TextBoxInner(npyscreen.MultiLineEdit):
|
||||
def __init__(self, *args, **kwargs):
|
||||
@ -324,55 +333,6 @@ class TextBoxInner(npyscreen.MultiLineEdit):
|
||||
if bstate & (BUTTON2_CLICKED | BUTTON3_CLICKED):
|
||||
self.h_paste()
|
||||
|
||||
# def update(self, clear=True):
|
||||
# if clear:
|
||||
# self.clear()
|
||||
|
||||
# HEIGHT = self.height
|
||||
# WIDTH = self.width
|
||||
# # draw box.
|
||||
# self.parent.curses_pad.hline(self.rely, self.relx, curses.ACS_HLINE, WIDTH)
|
||||
# self.parent.curses_pad.hline(
|
||||
# self.rely + HEIGHT, self.relx, curses.ACS_HLINE, WIDTH
|
||||
# )
|
||||
# self.parent.curses_pad.vline(
|
||||
# self.rely, self.relx, curses.ACS_VLINE, self.height
|
||||
# )
|
||||
# self.parent.curses_pad.vline(
|
||||
# self.rely, self.relx + WIDTH, curses.ACS_VLINE, HEIGHT
|
||||
# )
|
||||
|
||||
# # draw corners
|
||||
# self.parent.curses_pad.addch(
|
||||
# self.rely,
|
||||
# self.relx,
|
||||
# curses.ACS_ULCORNER,
|
||||
# )
|
||||
# self.parent.curses_pad.addch(
|
||||
# self.rely,
|
||||
# self.relx + WIDTH,
|
||||
# curses.ACS_URCORNER,
|
||||
# )
|
||||
# self.parent.curses_pad.addch(
|
||||
# self.rely + HEIGHT,
|
||||
# self.relx,
|
||||
# curses.ACS_LLCORNER,
|
||||
# )
|
||||
# self.parent.curses_pad.addch(
|
||||
# self.rely + HEIGHT,
|
||||
# self.relx + WIDTH,
|
||||
# curses.ACS_LRCORNER,
|
||||
# )
|
||||
|
||||
# # fool our superclass into thinking drawing area is smaller - this is really hacky but it seems to work
|
||||
# (relx, rely, height, width) = (self.relx, self.rely, self.height, self.width)
|
||||
# self.relx += 1
|
||||
# self.rely += 1
|
||||
# self.height -= 1
|
||||
# self.width -= 1
|
||||
# super().update(clear=False)
|
||||
# (self.relx, self.rely, self.height, self.width) = (relx, rely, height, width)
|
||||
|
||||
|
||||
class TextBox(npyscreen.BoxTitle):
|
||||
_contained_widget = TextBoxInner
|
||||
|
@ -9,8 +9,8 @@ module.exports = {
|
||||
'plugin:@typescript-eslint/recommended',
|
||||
'plugin:react/recommended',
|
||||
'plugin:react-hooks/recommended',
|
||||
'plugin:prettier/recommended',
|
||||
'plugin:react/jsx-runtime',
|
||||
'prettier',
|
||||
],
|
||||
parser: '@typescript-eslint/parser',
|
||||
parserOptions: {
|
||||
@ -23,6 +23,11 @@ module.exports = {
|
||||
plugins: ['react', '@typescript-eslint', 'eslint-plugin-react-hooks'],
|
||||
root: true,
|
||||
rules: {
|
||||
curly: 'error',
|
||||
'react/jsx-curly-brace-presence': [
|
||||
'error',
|
||||
{ props: 'never', children: 'never' },
|
||||
],
|
||||
'react-hooks/exhaustive-deps': 'error',
|
||||
'no-var': 'error',
|
||||
'brace-style': 'error',
|
||||
@ -34,7 +39,6 @@ module.exports = {
|
||||
'warn',
|
||||
{ varsIgnorePattern: '^_', argsIgnorePattern: '^_' },
|
||||
],
|
||||
'prettier/prettier': ['error', { endOfLine: 'auto' }],
|
||||
'@typescript-eslint/ban-ts-comment': 'warn',
|
||||
'@typescript-eslint/no-explicit-any': 'warn',
|
||||
'@typescript-eslint/no-empty-interface': [
|
||||
|
@ -29,12 +29,13 @@
|
||||
"lint:eslint": "eslint --max-warnings=0 .",
|
||||
"lint:prettier": "prettier --check .",
|
||||
"lint:tsc": "tsc --noEmit",
|
||||
"lint": "yarn run lint:eslint && yarn run lint:prettier && yarn run lint:tsc && yarn run lint:madge",
|
||||
"lint": "concurrently -g -n eslint,prettier,tsc,madge -c cyan,green,magenta,yellow \"yarn run lint:eslint\" \"yarn run lint:prettier\" \"yarn run lint:tsc\" \"yarn run lint:madge\"",
|
||||
"fix": "eslint --fix . && prettier --loglevel warn --write . && tsc --noEmit",
|
||||
"lint-staged": "lint-staged",
|
||||
"postinstall": "patch-package && yarn run theme",
|
||||
"theme": "chakra-cli tokens src/theme/theme.ts",
|
||||
"theme:watch": "chakra-cli tokens src/theme/theme.ts --watch"
|
||||
"theme:watch": "chakra-cli tokens src/theme/theme.ts --watch",
|
||||
"up": "yarn upgrade-interactive --latest"
|
||||
},
|
||||
"madge": {
|
||||
"detectiveOptions": {
|
||||
@ -54,7 +55,7 @@
|
||||
},
|
||||
"dependencies": {
|
||||
"@chakra-ui/anatomy": "^2.2.0",
|
||||
"@chakra-ui/icons": "^2.0.19",
|
||||
"@chakra-ui/icons": "^2.1.0",
|
||||
"@chakra-ui/react": "^2.8.0",
|
||||
"@chakra-ui/styled-system": "^2.9.1",
|
||||
"@chakra-ui/theme-tools": "^2.1.0",
|
||||
@ -65,55 +66,55 @@
|
||||
"@emotion/react": "^11.11.1",
|
||||
"@emotion/styled": "^11.11.0",
|
||||
"@floating-ui/react-dom": "^2.0.1",
|
||||
"@fontsource-variable/inter": "^5.0.3",
|
||||
"@fontsource/inter": "^5.0.3",
|
||||
"@mantine/core": "^6.0.14",
|
||||
"@mantine/form": "^6.0.15",
|
||||
"@mantine/hooks": "^6.0.14",
|
||||
"@fontsource-variable/inter": "^5.0.8",
|
||||
"@fontsource/inter": "^5.0.8",
|
||||
"@mantine/core": "^6.0.19",
|
||||
"@mantine/form": "^6.0.19",
|
||||
"@mantine/hooks": "^6.0.19",
|
||||
"@nanostores/react": "^0.7.1",
|
||||
"@reduxjs/toolkit": "^1.9.5",
|
||||
"@roarr/browser-log-writer": "^1.1.5",
|
||||
"chakra-ui-contextmenu": "^1.0.5",
|
||||
"dateformat": "^5.0.3",
|
||||
"downshift": "^7.6.0",
|
||||
"formik": "^2.4.2",
|
||||
"framer-motion": "^10.12.17",
|
||||
"formik": "^2.4.3",
|
||||
"framer-motion": "^10.16.1",
|
||||
"fuse.js": "^6.6.2",
|
||||
"i18next": "^23.2.3",
|
||||
"i18next": "^23.4.4",
|
||||
"i18next-browser-languagedetector": "^7.0.2",
|
||||
"i18next-http-backend": "^2.2.1",
|
||||
"konva": "^9.2.0",
|
||||
"lodash-es": "^4.17.21",
|
||||
"nanostores": "^0.9.2",
|
||||
"openapi-fetch": "^0.6.1",
|
||||
"new-github-issue-url": "^1.0.0",
|
||||
"openapi-fetch": "^0.7.4",
|
||||
"overlayscrollbars": "^2.2.0",
|
||||
"overlayscrollbars-react": "^0.5.0",
|
||||
"patch-package": "^7.0.0",
|
||||
"patch-package": "^8.0.0",
|
||||
"query-string": "^8.1.0",
|
||||
"re-resizable": "^6.9.9",
|
||||
"react": "^18.2.0",
|
||||
"react-colorful": "^5.6.1",
|
||||
"react-dom": "^18.2.0",
|
||||
"react-dropzone": "^14.2.3",
|
||||
"react-hotkeys-hook": "4.4.0",
|
||||
"react-i18next": "^13.0.1",
|
||||
"react-error-boundary": "^4.0.11",
|
||||
"react-hotkeys-hook": "4.4.1",
|
||||
"react-i18next": "^13.1.2",
|
||||
"react-icons": "^4.10.1",
|
||||
"react-konva": "^18.2.10",
|
||||
"react-redux": "^8.1.1",
|
||||
"react-resizable-panels": "^0.0.52",
|
||||
"react-redux": "^8.1.2",
|
||||
"react-resizable-panels": "^0.0.55",
|
||||
"react-use": "^17.4.0",
|
||||
"react-virtuoso": "^4.3.11",
|
||||
"react-virtuoso": "^4.5.0",
|
||||
"react-zoom-pan-pinch": "^3.0.8",
|
||||
"reactflow": "^11.7.4",
|
||||
"reactflow": "^11.8.3",
|
||||
"redux-dynamic-middlewares": "^2.2.0",
|
||||
"redux-remember": "^3.3.1",
|
||||
"roarr": "^7.15.0",
|
||||
"serialize-error": "^11.0.0",
|
||||
"socket.io-client": "^4.7.0",
|
||||
"redux-remember": "^4.0.1",
|
||||
"roarr": "^7.15.1",
|
||||
"serialize-error": "^11.0.1",
|
||||
"socket.io-client": "^4.7.2",
|
||||
"use-debounce": "^9.0.4",
|
||||
"use-image": "^1.1.1",
|
||||
"uuid": "^9.0.0",
|
||||
"zod": "^3.21.4"
|
||||
"zod": "^3.22.2",
|
||||
"zod-validation-error": "^1.5.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@chakra-ui/cli": "^2.4.0",
|
||||
@ -126,38 +127,36 @@
|
||||
"@chakra-ui/cli": "^2.4.1",
|
||||
"@types/dateformat": "^5.0.0",
|
||||
"@types/lodash-es": "^4.14.194",
|
||||
"@types/node": "^20.3.1",
|
||||
"@types/react": "^18.2.14",
|
||||
"@types/node": "^20.5.1",
|
||||
"@types/react": "^18.2.20",
|
||||
"@types/react-dom": "^18.2.6",
|
||||
"@types/react-redux": "^7.1.25",
|
||||
"@types/react-transition-group": "^4.4.6",
|
||||
"@types/uuid": "^9.0.2",
|
||||
"@typescript-eslint/eslint-plugin": "^5.60.0",
|
||||
"@typescript-eslint/parser": "^5.60.0",
|
||||
"@typescript-eslint/eslint-plugin": "^6.4.1",
|
||||
"@typescript-eslint/parser": "^6.4.1",
|
||||
"@vitejs/plugin-react-swc": "^3.3.2",
|
||||
"axios": "^1.4.0",
|
||||
"babel-plugin-transform-imports": "^2.0.0",
|
||||
"concurrently": "^8.2.0",
|
||||
"eslint": "^8.43.0",
|
||||
"eslint-config-prettier": "^8.8.0",
|
||||
"eslint-plugin-prettier": "^4.2.1",
|
||||
"eslint-plugin-react": "^7.32.2",
|
||||
"eslint": "^8.47.0",
|
||||
"eslint-config-prettier": "^9.0.0",
|
||||
"eslint-plugin-prettier": "^5.0.0",
|
||||
"eslint-plugin-react": "^7.33.2",
|
||||
"eslint-plugin-react-hooks": "^4.6.0",
|
||||
"form-data": "^4.0.0",
|
||||
"husky": "^8.0.3",
|
||||
"lint-staged": "^13.2.2",
|
||||
"lint-staged": "^14.0.1",
|
||||
"madge": "^6.1.0",
|
||||
"openapi-types": "^12.1.3",
|
||||
"openapi-typescript": "^6.2.8",
|
||||
"openapi-typescript-codegen": "^0.24.0",
|
||||
"openapi-typescript": "^6.5.2",
|
||||
"postinstall-postinstall": "^2.1.0",
|
||||
"prettier": "^2.8.8",
|
||||
"prettier": "^3.0.2",
|
||||
"rollup-plugin-visualizer": "^5.9.2",
|
||||
"terser": "^5.18.1",
|
||||
"ts-toolbelt": "^9.6.0",
|
||||
"vite": "^4.3.9",
|
||||
"vite-plugin-css-injected-by-js": "^3.1.1",
|
||||
"vite-plugin-dts": "^2.3.0",
|
||||
"vite": "^4.4.9",
|
||||
"vite-plugin-css-injected-by-js": "^3.3.0",
|
||||
"vite-plugin-dts": "^3.5.2",
|
||||
"vite-plugin-eslint": "^1.8.1",
|
||||
"vite-tsconfig-paths": "^4.2.0",
|
||||
"yarn": "^1.22.19"
|
||||
|
@ -19,7 +19,7 @@
|
||||
"toggleAutoscroll": "Toggle autoscroll",
|
||||
"toggleLogViewer": "Toggle Log Viewer",
|
||||
"showGallery": "Show Gallery",
|
||||
"showOptionsPanel": "Show Options Panel",
|
||||
"showOptionsPanel": "Show Side Panel",
|
||||
"menu": "Menu"
|
||||
},
|
||||
"common": {
|
||||
@ -52,7 +52,7 @@
|
||||
"img2img": "Image To Image",
|
||||
"unifiedCanvas": "Unified Canvas",
|
||||
"linear": "Linear",
|
||||
"nodes": "Node Editor",
|
||||
"nodes": "Workflow Editor",
|
||||
"batch": "Batch Manager",
|
||||
"modelManager": "Model Manager",
|
||||
"postprocessing": "Post Processing",
|
||||
@ -95,7 +95,6 @@
|
||||
"statusModelConverted": "Model Converted",
|
||||
"statusMergingModels": "Merging Models",
|
||||
"statusMergedModels": "Models Merged",
|
||||
"pinOptionsPanel": "Pin Options Panel",
|
||||
"loading": "Loading",
|
||||
"loadingInvokeAI": "Loading Invoke AI",
|
||||
"random": "Random",
|
||||
@ -116,7 +115,6 @@
|
||||
"maintainAspectRatio": "Maintain Aspect Ratio",
|
||||
"autoSwitchNewImages": "Auto-Switch to New Images",
|
||||
"singleColumnLayout": "Single Column Layout",
|
||||
"pinGallery": "Pin Gallery",
|
||||
"allImagesLoaded": "All Images Loaded",
|
||||
"loadMore": "Load More",
|
||||
"noImagesInGallery": "No Images to Display",
|
||||
@ -133,6 +131,7 @@
|
||||
"generalHotkeys": "General Hotkeys",
|
||||
"galleryHotkeys": "Gallery Hotkeys",
|
||||
"unifiedCanvasHotkeys": "Unified Canvas Hotkeys",
|
||||
"nodesHotkeys": "Nodes Hotkeys",
|
||||
"invoke": {
|
||||
"title": "Invoke",
|
||||
"desc": "Generate an image"
|
||||
@ -332,6 +331,10 @@
|
||||
"acceptStagingImage": {
|
||||
"title": "Accept Staging Image",
|
||||
"desc": "Accept Current Staging Area Image"
|
||||
},
|
||||
"addNodes": {
|
||||
"title": "Add Nodes",
|
||||
"desc": "Opens the add node menu"
|
||||
}
|
||||
},
|
||||
"modelManager": {
|
||||
@ -506,12 +509,9 @@
|
||||
"maskAdjustmentsHeader": "Mask Adjustments",
|
||||
"maskBlur": "Mask Blur",
|
||||
"maskBlurMethod": "Mask Blur Method",
|
||||
"seamPaintingHeader": "Seam Painting",
|
||||
"seamSize": "Seam Size",
|
||||
"seamBlur": "Seam Blur",
|
||||
"seamSteps": "Seam Steps",
|
||||
"seamStrength": "Seam Strength",
|
||||
"seamThreshold": "Seam Threshold",
|
||||
"coherencePassHeader": "Coherence Pass",
|
||||
"coherenceSteps": "Coherence Pass Steps",
|
||||
"coherenceStrength": "Coherence Pass Strength",
|
||||
"seamLowThreshold": "Low",
|
||||
"seamHighThreshold": "High",
|
||||
"scaleBeforeProcessing": "Scale Before Processing",
|
||||
@ -572,7 +572,7 @@
|
||||
"resetWebUI": "Reset Web UI",
|
||||
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
|
||||
"resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.",
|
||||
"resetComplete": "Web UI has been reset. Refresh the page to reload.",
|
||||
"resetComplete": "Web UI has been reset.",
|
||||
"consoleLogLevel": "Log Level",
|
||||
"shouldLogToConsole": "Console Logging",
|
||||
"developer": "Developer",
|
||||
@ -715,11 +715,12 @@
|
||||
"swapSizes": "Swap Sizes"
|
||||
},
|
||||
"nodes": {
|
||||
"reloadSchema": "Reload Schema",
|
||||
"saveGraph": "Save Graph",
|
||||
"loadGraph": "Load Graph (saved from Node Editor) (Do not copy-paste metadata)",
|
||||
"clearGraph": "Clear Graph",
|
||||
"clearGraphDesc": "Are you sure you want to clear all nodes?",
|
||||
"reloadNodeTemplates": "Reload Node Templates",
|
||||
"saveWorkflow": "Save Workflow",
|
||||
"loadWorkflow": "Load Workflow",
|
||||
"resetWorkflow": "Reset Workflow",
|
||||
"resetWorkflowDesc": "Are you sure you want to reset this workflow?",
|
||||
"resetWorkflowDesc2": "Resetting the workflow will clear all nodes, edges and workflow details.",
|
||||
"zoomInNodes": "Zoom In",
|
||||
"zoomOutNodes": "Zoom Out",
|
||||
"fitViewportNodes": "Fit View",
|
||||
|
@ -27,22 +27,10 @@ async function main() {
|
||||
* field accepts connection input. If it does, we can make the field optional.
|
||||
*/
|
||||
|
||||
// Check if we are generating types for an invocation
|
||||
const isInvocationPath = metadata.path.match(
|
||||
/^#\/components\/schemas\/\w*Invocation$/
|
||||
);
|
||||
|
||||
const hasInvocationProperties =
|
||||
schemaObject.properties &&
|
||||
['id', 'is_intermediate', 'type'].every(
|
||||
(prop) => prop in schemaObject.properties
|
||||
);
|
||||
|
||||
if (isInvocationPath && hasInvocationProperties) {
|
||||
if ('class' in schemaObject && schemaObject.class === 'invocation') {
|
||||
// We only want to make fields optional if they are required
|
||||
if (!Array.isArray(schemaObject?.required)) {
|
||||
schemaObject.required = ['id', 'type'];
|
||||
return;
|
||||
schemaObject.required = [];
|
||||
}
|
||||
|
||||
schemaObject.required.forEach((prop) => {
|
||||
@ -61,19 +49,13 @@ async function main() {
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
schemaObject.required = [
|
||||
...new Set(schemaObject.required.concat(['id', 'type'])),
|
||||
];
|
||||
|
||||
return;
|
||||
}
|
||||
// if (
|
||||
// 'input' in schemaObject &&
|
||||
// (schemaObject.input === 'any' || schemaObject.input === 'connection')
|
||||
// ) {
|
||||
// schemaObject.required = false;
|
||||
// }
|
||||
|
||||
// Check if we are generating types for an invocation output
|
||||
if ('class' in schemaObject && schemaObject.class === 'output') {
|
||||
// modify output types
|
||||
}
|
||||
},
|
||||
});
|
||||
fs.writeFileSync(OUTPUT_FILE, types);
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { Flex, Grid, Portal } from '@chakra-ui/react';
|
||||
import { Flex, Grid } from '@chakra-ui/react';
|
||||
import { useLogger } from 'app/logging/useLogger';
|
||||
import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
@ -6,17 +6,15 @@ import { PartialAppConfig } from 'app/types/invokeai';
|
||||
import ImageUploader from 'common/components/ImageUploader';
|
||||
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
|
||||
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
|
||||
import GalleryDrawer from 'features/gallery/components/GalleryPanel';
|
||||
import SiteHeader from 'features/system/components/SiteHeader';
|
||||
import { configChanged } from 'features/system/store/configSlice';
|
||||
import { languageSelector } from 'features/system/store/systemSelectors';
|
||||
import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton';
|
||||
import FloatingParametersPanelButtons from 'features/ui/components/FloatingParametersPanelButtons';
|
||||
import InvokeTabs from 'features/ui/components/InvokeTabs';
|
||||
import ParametersDrawer from 'features/ui/components/ParametersDrawer';
|
||||
import i18n from 'i18n';
|
||||
import { size } from 'lodash-es';
|
||||
import { ReactNode, memo, useEffect } from 'react';
|
||||
import { ReactNode, memo, useCallback, useEffect } from 'react';
|
||||
import { ErrorBoundary } from 'react-error-boundary';
|
||||
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
|
||||
import GlobalHotkeys from './GlobalHotkeys';
|
||||
import Toaster from './Toaster';
|
||||
|
||||
@ -30,8 +28,13 @@ interface Props {
|
||||
const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
|
||||
const language = useAppSelector(languageSelector);
|
||||
|
||||
const logger = useLogger();
|
||||
const logger = useLogger('system');
|
||||
const dispatch = useAppDispatch();
|
||||
const handleReset = useCallback(() => {
|
||||
localStorage.clear();
|
||||
location.reload();
|
||||
return false;
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
i18n.changeLanguage(language);
|
||||
@ -39,7 +42,7 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
|
||||
|
||||
useEffect(() => {
|
||||
if (size(config)) {
|
||||
logger.info({ namespace: 'App', config }, 'Received config');
|
||||
logger.info({ config }, 'Received config');
|
||||
dispatch(configChanged(config));
|
||||
}
|
||||
}, [dispatch, config, logger]);
|
||||
@ -49,7 +52,10 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<ErrorBoundary
|
||||
onReset={handleReset}
|
||||
FallbackComponent={AppErrorBoundaryFallback}
|
||||
>
|
||||
<Grid w="100vw" h="100vh" position="relative" overflow="hidden">
|
||||
<ImageUploader>
|
||||
<Grid
|
||||
@ -73,21 +79,12 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
|
||||
</Flex>
|
||||
</Grid>
|
||||
</ImageUploader>
|
||||
|
||||
<GalleryDrawer />
|
||||
<ParametersDrawer />
|
||||
<Portal>
|
||||
<FloatingParametersPanelButtons />
|
||||
</Portal>
|
||||
<Portal>
|
||||
<FloatingGalleryButton />
|
||||
</Portal>
|
||||
</Grid>
|
||||
<DeleteImageModal />
|
||||
<ChangeBoardModal />
|
||||
<Toaster />
|
||||
<GlobalHotkeys />
|
||||
</>
|
||||
</ErrorBoundary>
|
||||
);
|
||||
};
|
||||
|
||||
|
@ -0,0 +1,97 @@
|
||||
import { Flex, Heading, Link, Text, useToast } from '@chakra-ui/react';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import newGithubIssueUrl from 'new-github-issue-url';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { FaCopy, FaExternalLinkAlt } from 'react-icons/fa';
|
||||
import { FaArrowRotateLeft } from 'react-icons/fa6';
|
||||
import { serializeError } from 'serialize-error';
|
||||
|
||||
type Props = {
|
||||
error: Error;
|
||||
resetErrorBoundary: () => void;
|
||||
};
|
||||
|
||||
const AppErrorBoundaryFallback = ({ error, resetErrorBoundary }: Props) => {
|
||||
const toast = useToast();
|
||||
|
||||
const handleCopy = useCallback(() => {
|
||||
const text = JSON.stringify(serializeError(error), null, 2);
|
||||
navigator.clipboard.writeText(`\`\`\`\n${text}\n\`\`\``);
|
||||
toast({
|
||||
title: 'Error Copied',
|
||||
});
|
||||
}, [error, toast]);
|
||||
|
||||
const url = useMemo(
|
||||
() =>
|
||||
newGithubIssueUrl({
|
||||
user: 'invoke-ai',
|
||||
repo: 'InvokeAI',
|
||||
template: 'BUG_REPORT.yml',
|
||||
title: `[bug]: ${error.name}: ${error.message}`,
|
||||
}),
|
||||
[error.message, error.name]
|
||||
);
|
||||
return (
|
||||
<Flex
|
||||
layerStyle="body"
|
||||
sx={{
|
||||
w: '100vw',
|
||||
h: '100vh',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
p: 4,
|
||||
}}
|
||||
>
|
||||
<Flex
|
||||
layerStyle="first"
|
||||
sx={{
|
||||
flexDir: 'column',
|
||||
borderRadius: 'base',
|
||||
justifyContent: 'center',
|
||||
gap: 8,
|
||||
p: 16,
|
||||
}}
|
||||
>
|
||||
<Heading>Something went wrong</Heading>
|
||||
<Flex
|
||||
layerStyle="second"
|
||||
sx={{
|
||||
px: 8,
|
||||
py: 4,
|
||||
borderRadius: 'base',
|
||||
gap: 4,
|
||||
justifyContent: 'space-between',
|
||||
alignItems: 'center',
|
||||
}}
|
||||
>
|
||||
<Text
|
||||
sx={{
|
||||
fontWeight: 600,
|
||||
color: 'error.500',
|
||||
_dark: { color: 'error.400' },
|
||||
}}
|
||||
>
|
||||
{error.name}: {error.message}
|
||||
</Text>
|
||||
</Flex>
|
||||
<Flex sx={{ gap: 4 }}>
|
||||
<IAIButton
|
||||
leftIcon={<FaArrowRotateLeft />}
|
||||
onClick={resetErrorBoundary}
|
||||
>
|
||||
Reset UI
|
||||
</IAIButton>
|
||||
<IAIButton leftIcon={<FaCopy />} onClick={handleCopy}>
|
||||
Copy Error
|
||||
</IAIButton>
|
||||
<Link href={url} isExternal>
|
||||
<IAIButton leftIcon={<FaExternalLinkAlt />}>Create Issue</IAIButton>
|
||||
</Link>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(AppErrorBoundaryFallback);
|
@ -1,30 +1,21 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import {
|
||||
ctrlKeyPressed,
|
||||
metaKeyPressed,
|
||||
shiftKeyPressed,
|
||||
} from 'features/ui/store/hotkeysSlice';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import {
|
||||
setActiveTab,
|
||||
toggleGalleryPanel,
|
||||
toggleParametersPanel,
|
||||
togglePinGalleryPanel,
|
||||
togglePinParametersPanel,
|
||||
} from 'features/ui/store/uiSlice';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import React, { memo } from 'react';
|
||||
import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
|
||||
|
||||
const globalHotkeysSelector = createSelector(
|
||||
[stateSelector],
|
||||
({ hotkeys, ui }) => {
|
||||
({ hotkeys }) => {
|
||||
const { shift, ctrl, meta } = hotkeys;
|
||||
const { shouldPinParametersPanel, shouldPinGallery } = ui;
|
||||
return { shift, ctrl, meta, shouldPinGallery, shouldPinParametersPanel };
|
||||
return { shift, ctrl, meta };
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
@ -41,9 +32,7 @@ const globalHotkeysSelector = createSelector(
|
||||
*/
|
||||
const GlobalHotkeys: React.FC = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { shift, ctrl, meta, shouldPinParametersPanel, shouldPinGallery } =
|
||||
useAppSelector(globalHotkeysSelector);
|
||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||
const { shift, ctrl, meta } = useAppSelector(globalHotkeysSelector);
|
||||
|
||||
useHotkeys(
|
||||
'*',
|
||||
@ -68,34 +57,6 @@ const GlobalHotkeys: React.FC = () => {
|
||||
[shift, ctrl, meta]
|
||||
);
|
||||
|
||||
useHotkeys('o', () => {
|
||||
dispatch(toggleParametersPanel());
|
||||
if (activeTabName === 'unifiedCanvas' && shouldPinParametersPanel) {
|
||||
dispatch(requestCanvasRescale());
|
||||
}
|
||||
});
|
||||
|
||||
useHotkeys(['shift+o'], () => {
|
||||
dispatch(togglePinParametersPanel());
|
||||
if (activeTabName === 'unifiedCanvas') {
|
||||
dispatch(requestCanvasRescale());
|
||||
}
|
||||
});
|
||||
|
||||
useHotkeys('g', () => {
|
||||
dispatch(toggleGalleryPanel());
|
||||
if (activeTabName === 'unifiedCanvas' && shouldPinGallery) {
|
||||
dispatch(requestCanvasRescale());
|
||||
}
|
||||
});
|
||||
|
||||
useHotkeys(['shift+g'], () => {
|
||||
dispatch(togglePinGalleryPanel());
|
||||
if (activeTabName === 'unifiedCanvas') {
|
||||
dispatch(requestCanvasRescale());
|
||||
}
|
||||
});
|
||||
|
||||
useHotkeys('1', () => {
|
||||
dispatch(setActiveTab('txt2img'));
|
||||
});
|
||||
@ -112,6 +73,10 @@ const GlobalHotkeys: React.FC = () => {
|
||||
dispatch(setActiveTab('nodes'));
|
||||
});
|
||||
|
||||
useHotkeys('5', () => {
|
||||
dispatch(setActiveTab('modelManager'));
|
||||
});
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
|
@ -3,7 +3,7 @@ import {
|
||||
createLocalStorageManager,
|
||||
extendTheme,
|
||||
} from '@chakra-ui/react';
|
||||
import { ReactNode, useEffect, useMemo } from 'react';
|
||||
import { ReactNode, memo, useEffect, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { theme as invokeAITheme } from 'theme/theme';
|
||||
|
||||
@ -46,4 +46,4 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
|
||||
);
|
||||
}
|
||||
|
||||
export default ThemeLocaleProvider;
|
||||
export default memo(ThemeLocaleProvider);
|
||||
|
@ -3,7 +3,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { toastQueueSelector } from 'features/system/store/systemSelectors';
|
||||
import { addToast, clearToastQueue } from 'features/system/store/systemSlice';
|
||||
import { MakeToastArg, makeToast } from 'features/system/util/makeToast';
|
||||
import { useCallback, useEffect } from 'react';
|
||||
import { memo, useCallback, useEffect } from 'react';
|
||||
|
||||
/**
|
||||
* Logical component. Watches the toast queue and makes toasts when the queue is not empty.
|
||||
@ -44,4 +44,4 @@ export const useAppToaster = () => {
|
||||
return toaster;
|
||||
};
|
||||
|
||||
export default Toaster;
|
||||
export default memo(Toaster);
|
||||
|
@ -9,7 +9,7 @@ export const log = Roarr.child(BASE_CONTEXT);
|
||||
|
||||
export const $logger = atom<Logger>(Roarr.child(BASE_CONTEXT));
|
||||
|
||||
type LoggerNamespace =
|
||||
export type LoggerNamespace =
|
||||
| 'images'
|
||||
| 'models'
|
||||
| 'config'
|
||||
|
@ -1,12 +1,17 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { createLogWriter } from '@roarr/browser-log-writer';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { useEffect } from 'react';
|
||||
import { useEffect, useMemo } from 'react';
|
||||
import { ROARR, Roarr } from 'roarr';
|
||||
import { $logger, BASE_CONTEXT, LOG_LEVEL_MAP } from './logger';
|
||||
import {
|
||||
$logger,
|
||||
BASE_CONTEXT,
|
||||
LOG_LEVEL_MAP,
|
||||
LoggerNamespace,
|
||||
logger,
|
||||
} from './logger';
|
||||
|
||||
const selector = createSelector(
|
||||
systemSelector,
|
||||
@ -25,7 +30,7 @@ const selector = createSelector(
|
||||
}
|
||||
);
|
||||
|
||||
export const useLogger = () => {
|
||||
export const useLogger = (namespace: LoggerNamespace) => {
|
||||
const { consoleLogLevel, shouldLogToConsole } = useAppSelector(selector);
|
||||
|
||||
// The provided Roarr browser log writer uses localStorage to config logging to console
|
||||
@ -57,7 +62,7 @@ export const useLogger = () => {
|
||||
$logger.set(Roarr.child(newContext));
|
||||
}, []);
|
||||
|
||||
const logger = useStore($logger);
|
||||
const log = useMemo(() => logger(namespace), [namespace]);
|
||||
|
||||
return logger;
|
||||
return log;
|
||||
};
|
||||
|
@ -1,13 +1,17 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { resetCanvas } from 'features/canvas/store/canvasSlice';
|
||||
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
|
||||
import {
|
||||
controlNetImageChanged,
|
||||
controlNetProcessedImageChanged,
|
||||
} from 'features/controlNet/store/controlNetSlice';
|
||||
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
|
||||
import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
|
||||
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/types';
|
||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||
import { clamp } from 'lodash-es';
|
||||
import { clamp, forEach } from 'lodash-es';
|
||||
import { api } from 'services/api';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { imagesAdapter } from 'services/api/util';
|
||||
@ -73,22 +77,61 @@ export const addRequestedSingleImageDeletionListener = () => {
|
||||
}
|
||||
|
||||
// We need to reset the features where the image is in use - none of these work if their image(s) don't exist
|
||||
|
||||
if (imageUsage.isCanvasImage) {
|
||||
dispatch(resetCanvas());
|
||||
}
|
||||
|
||||
if (imageUsage.isControlNetImage) {
|
||||
dispatch(controlNetReset());
|
||||
}
|
||||
imageDTOs.forEach((imageDTO) => {
|
||||
// reset init image if we deleted it
|
||||
if (
|
||||
getState().generation.initialImage?.imageName === imageDTO.image_name
|
||||
) {
|
||||
dispatch(clearInitialImage());
|
||||
}
|
||||
|
||||
if (imageUsage.isInitialImage) {
|
||||
dispatch(clearInitialImage());
|
||||
}
|
||||
// reset controlNets that use the deleted images
|
||||
forEach(getState().controlNet.controlNets, (controlNet) => {
|
||||
if (
|
||||
controlNet.controlImage === imageDTO.image_name ||
|
||||
controlNet.processedControlImage === imageDTO.image_name
|
||||
) {
|
||||
dispatch(
|
||||
controlNetImageChanged({
|
||||
controlNetId: controlNet.controlNetId,
|
||||
controlImage: null,
|
||||
})
|
||||
);
|
||||
dispatch(
|
||||
controlNetProcessedImageChanged({
|
||||
controlNetId: controlNet.controlNetId,
|
||||
processedControlImage: null,
|
||||
})
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
if (imageUsage.isNodesImage) {
|
||||
dispatch(nodeEditorReset());
|
||||
}
|
||||
// reset nodes that use the deleted images
|
||||
getState().nodes.nodes.forEach((node) => {
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
|
||||
forEach(node.data.inputs, (input) => {
|
||||
if (
|
||||
input.type === 'ImageField' &&
|
||||
input.value?.image_name === imageDTO.image_name
|
||||
) {
|
||||
dispatch(
|
||||
fieldImageValueChanged({
|
||||
nodeId: node.data.id,
|
||||
fieldName: input.name,
|
||||
value: undefined,
|
||||
})
|
||||
);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Delete from server
|
||||
const { requestId } = dispatch(
|
||||
@ -154,17 +197,58 @@ export const addRequestedMultipleImageDeletionListener = () => {
|
||||
dispatch(resetCanvas());
|
||||
}
|
||||
|
||||
if (imagesUsage.some((i) => i.isControlNetImage)) {
|
||||
dispatch(controlNetReset());
|
||||
}
|
||||
imageDTOs.forEach((imageDTO) => {
|
||||
// reset init image if we deleted it
|
||||
if (
|
||||
getState().generation.initialImage?.imageName ===
|
||||
imageDTO.image_name
|
||||
) {
|
||||
dispatch(clearInitialImage());
|
||||
}
|
||||
|
||||
if (imagesUsage.some((i) => i.isInitialImage)) {
|
||||
dispatch(clearInitialImage());
|
||||
}
|
||||
// reset controlNets that use the deleted images
|
||||
forEach(getState().controlNet.controlNets, (controlNet) => {
|
||||
if (
|
||||
controlNet.controlImage === imageDTO.image_name ||
|
||||
controlNet.processedControlImage === imageDTO.image_name
|
||||
) {
|
||||
dispatch(
|
||||
controlNetImageChanged({
|
||||
controlNetId: controlNet.controlNetId,
|
||||
controlImage: null,
|
||||
})
|
||||
);
|
||||
dispatch(
|
||||
controlNetProcessedImageChanged({
|
||||
controlNetId: controlNet.controlNetId,
|
||||
processedControlImage: null,
|
||||
})
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
if (imagesUsage.some((i) => i.isNodesImage)) {
|
||||
dispatch(nodeEditorReset());
|
||||
}
|
||||
// reset nodes that use the deleted images
|
||||
getState().nodes.nodes.forEach((node) => {
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
|
||||
forEach(node.data.inputs, (input) => {
|
||||
if (
|
||||
input.type === 'ImageField' &&
|
||||
input.value?.image_name === imageDTO.image_name
|
||||
) {
|
||||
dispatch(
|
||||
fieldImageValueChanged({
|
||||
nodeId: node.data.id,
|
||||
fieldName: input.name,
|
||||
value: undefined,
|
||||
})
|
||||
);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
} catch {
|
||||
// no-op
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ import { modelsApi } from 'services/api/endpoints/models';
|
||||
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
||||
import { appSocketConnected, socketConnected } from 'services/events/actions';
|
||||
import { startAppListening } from '../..';
|
||||
import { size } from 'lodash-es';
|
||||
|
||||
export const addSocketConnectedEventListener = () => {
|
||||
startAppListening({
|
||||
@ -18,7 +19,7 @@ export const addSocketConnectedEventListener = () => {
|
||||
|
||||
const { disabledTabs } = config;
|
||||
|
||||
if (!nodes.schema && !disabledTabs.includes('nodes')) {
|
||||
if (!size(nodes.nodeTemplates) && !disabledTabs.includes('nodes')) {
|
||||
dispatch(receivedOpenAPISchema());
|
||||
}
|
||||
|
||||
|
@ -8,8 +8,8 @@ import {
|
||||
import { memo, ReactNode } from 'react';
|
||||
|
||||
export interface IAIButtonProps extends ButtonProps {
|
||||
tooltip?: string;
|
||||
tooltipProps?: Omit<TooltipProps, 'children'>;
|
||||
tooltip?: TooltipProps['label'];
|
||||
tooltipProps?: Omit<TooltipProps, 'children' | 'label'>;
|
||||
isChecked?: boolean;
|
||||
children: ReactNode;
|
||||
}
|
||||
|
@ -34,14 +34,10 @@ const IAICollapse = (props: IAIToggleCollapseProps) => {
|
||||
gap: 2,
|
||||
borderTopRadius: 'base',
|
||||
borderBottomRadius: isOpen ? 0 : 'base',
|
||||
bg: isOpen
|
||||
? mode('base.200', 'base.750')(colorMode)
|
||||
: mode('base.150', 'base.800')(colorMode),
|
||||
bg: mode('base.250', 'base.750')(colorMode),
|
||||
color: mode('base.900', 'base.100')(colorMode),
|
||||
_hover: {
|
||||
bg: isOpen
|
||||
? mode('base.250', 'base.700')(colorMode)
|
||||
: mode('base.200', 'base.750')(colorMode),
|
||||
bg: mode('base.300', 'base.700')(colorMode),
|
||||
},
|
||||
fontSize: 'sm',
|
||||
fontWeight: 600,
|
||||
@ -90,9 +86,10 @@ const IAICollapse = (props: IAIToggleCollapseProps) => {
|
||||
<Collapse in={isOpen} animateOpacity style={{ overflow: 'unset' }}>
|
||||
<Box
|
||||
sx={{
|
||||
p: 4,
|
||||
p: 2,
|
||||
pt: 3,
|
||||
borderBottomRadius: 'base',
|
||||
bg: 'base.100',
|
||||
bg: 'base.150',
|
||||
_dark: {
|
||||
bg: 'base.800',
|
||||
},
|
||||
|
@ -100,14 +100,18 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
const handleMouseOver = useCallback(
|
||||
(e: MouseEvent<HTMLDivElement>) => {
|
||||
if (onMouseOver) onMouseOver(e);
|
||||
if (onMouseOver) {
|
||||
onMouseOver(e);
|
||||
}
|
||||
setIsHovered(true);
|
||||
},
|
||||
[onMouseOver]
|
||||
);
|
||||
const handleMouseOut = useCallback(
|
||||
(e: MouseEvent<HTMLDivElement>) => {
|
||||
if (onMouseOut) onMouseOut(e);
|
||||
if (onMouseOut) {
|
||||
onMouseOut(e);
|
||||
}
|
||||
setIsHovered(false);
|
||||
},
|
||||
[onMouseOut]
|
||||
@ -122,7 +126,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
? {}
|
||||
: {
|
||||
cursor: 'pointer',
|
||||
bg: mode('base.200', 'base.800')(colorMode),
|
||||
bg: mode('base.200', 'base.700')(colorMode),
|
||||
_hover: {
|
||||
bg: mode('base.300', 'base.650')(colorMode),
|
||||
color: mode('base.500', 'base.300')(colorMode),
|
||||
|
@ -1,4 +1,5 @@
|
||||
import { Box, Flex, Icon } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import { FaExclamation } from 'react-icons/fa';
|
||||
|
||||
const IAIErrorLoadingImageFallback = () => {
|
||||
@ -39,4 +40,4 @@ const IAIErrorLoadingImageFallback = () => {
|
||||
);
|
||||
};
|
||||
|
||||
export default IAIErrorLoadingImageFallback;
|
||||
export default memo(IAIErrorLoadingImageFallback);
|
||||
|
@ -1,4 +1,5 @@
|
||||
import { Box, Skeleton } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
|
||||
const IAIFillSkeleton = () => {
|
||||
return (
|
||||
@ -27,4 +28,4 @@ const IAIFillSkeleton = () => {
|
||||
);
|
||||
};
|
||||
|
||||
export default IAIFillSkeleton;
|
||||
export default memo(IAIFillSkeleton);
|
||||
|
@ -9,8 +9,8 @@ import { memo } from 'react';
|
||||
|
||||
export type IAIIconButtonProps = IconButtonProps & {
|
||||
role?: string;
|
||||
tooltip?: string;
|
||||
tooltipProps?: Omit<TooltipProps, 'children'>;
|
||||
tooltip?: TooltipProps['label'];
|
||||
tooltipProps?: Omit<TooltipProps, 'children' | 'label'>;
|
||||
isChecked?: boolean;
|
||||
};
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
import { Badge, Flex } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
|
||||
type ImageMetadataOverlayProps = {
|
||||
@ -26,4 +27,4 @@ const ImageMetadataOverlay = ({ imageDTO }: ImageMetadataOverlayProps) => {
|
||||
);
|
||||
};
|
||||
|
||||
export default ImageMetadataOverlay;
|
||||
export default memo(ImageMetadataOverlay);
|
||||
|
@ -1,4 +1,5 @@
|
||||
import { Box, Flex, Heading } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
|
||||
type ImageUploadOverlayProps = {
|
||||
@ -87,4 +88,4 @@ const ImageUploadOverlay = (props: ImageUploadOverlayProps) => {
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
export default ImageUploadOverlay;
|
||||
export default memo(ImageUploadOverlay);
|
||||
|
@ -150,7 +150,9 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
||||
{...getRootProps({ style: {} })}
|
||||
onKeyDown={(e: KeyboardEvent) => {
|
||||
// Bail out if user hits spacebar - do not open the uploader
|
||||
if (e.key === ' ') return;
|
||||
if (e.key === ' ') {
|
||||
return;
|
||||
}
|
||||
}}
|
||||
>
|
||||
<input {...getInputProps()} />
|
||||
|
@ -1,4 +1,5 @@
|
||||
import { Flex, Icon } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import { FaImage } from 'react-icons/fa';
|
||||
|
||||
const SelectImagePlaceholder = () => {
|
||||
@ -19,4 +20,4 @@ const SelectImagePlaceholder = () => {
|
||||
);
|
||||
};
|
||||
|
||||
export default SelectImagePlaceholder;
|
||||
export default memo(SelectImagePlaceholder);
|
||||
|
@ -1,4 +1,5 @@
|
||||
import { Box } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
|
||||
type Props = {
|
||||
isSelected: boolean;
|
||||
@ -18,6 +19,7 @@ const SelectionOverlay = ({ isSelected, isHovered }: Props) => {
|
||||
opacity: isSelected ? 1 : 0.7,
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: '0.1s',
|
||||
pointerEvents: 'none',
|
||||
shadow: isSelected
|
||||
? isHovered
|
||||
? 'hoverSelected.light'
|
||||
@ -39,4 +41,4 @@ const SelectionOverlay = ({ isSelected, isHovered }: Props) => {
|
||||
);
|
||||
};
|
||||
|
||||
export default SelectionOverlay;
|
||||
export default memo(SelectionOverlay);
|
||||
|
@ -2,71 +2,108 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
// import { validateSeedWeights } from 'common/util/seedWeightPairs';
|
||||
import { isInvocationNode } from 'features/nodes/types/types';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import { modelsApi } from '../../services/api/endpoints/models';
|
||||
import { forEach, map } from 'lodash-es';
|
||||
import { getConnectedEdges } from 'reactflow';
|
||||
|
||||
const readinessSelector = createSelector(
|
||||
const selector = createSelector(
|
||||
[stateSelector, activeTabNameSelector],
|
||||
(state, activeTabName) => {
|
||||
const { generation, system } = state;
|
||||
const { initialImage } = generation;
|
||||
const { generation, system, nodes } = state;
|
||||
const { initialImage, model } = generation;
|
||||
|
||||
const { isProcessing, isConnected } = system;
|
||||
|
||||
let isReady = true;
|
||||
const reasonsWhyNotReady: string[] = [];
|
||||
const reasons: string[] = [];
|
||||
|
||||
if (activeTabName === 'img2img' && !initialImage) {
|
||||
isReady = false;
|
||||
reasonsWhyNotReady.push('No initial image selected');
|
||||
}
|
||||
|
||||
const { isSuccess: mainModelsSuccessfullyLoaded } =
|
||||
modelsApi.endpoints.getMainModels.select(NON_REFINER_BASE_MODELS)(state);
|
||||
if (!mainModelsSuccessfullyLoaded) {
|
||||
isReady = false;
|
||||
reasonsWhyNotReady.push('Models are not loaded');
|
||||
}
|
||||
|
||||
// TODO: job queue
|
||||
// Cannot generate if already processing an image
|
||||
if (isProcessing) {
|
||||
isReady = false;
|
||||
reasonsWhyNotReady.push('System Busy');
|
||||
reasons.push('System busy');
|
||||
}
|
||||
|
||||
// Cannot generate if not connected
|
||||
if (!isConnected) {
|
||||
isReady = false;
|
||||
reasonsWhyNotReady.push('System Disconnected');
|
||||
reasons.push('System disconnected');
|
||||
}
|
||||
|
||||
// // Cannot generate variations without valid seed weights
|
||||
// if (
|
||||
// shouldGenerateVariations &&
|
||||
// (!(validateSeedWeights(seedWeights) || seedWeights === '') || seed === -1)
|
||||
// ) {
|
||||
// isReady = false;
|
||||
// reasonsWhyNotReady.push('Seed-Weights badly formatted.');
|
||||
// }
|
||||
if (activeTabName === 'img2img' && !initialImage) {
|
||||
reasons.push('No initial image selected');
|
||||
}
|
||||
|
||||
forEach(state.controlNet.controlNets, (controlNet, id) => {
|
||||
if (!controlNet.model) {
|
||||
isReady = false;
|
||||
reasonsWhyNotReady.push(`ControlNet ${id} has no model selected.`);
|
||||
if (activeTabName === 'nodes' && nodes.shouldValidateGraph) {
|
||||
if (!nodes.nodes.length) {
|
||||
reasons.push('No nodes in graph');
|
||||
}
|
||||
});
|
||||
|
||||
// All good
|
||||
return { isReady, reasonsWhyNotReady };
|
||||
nodes.nodes.forEach((node) => {
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const nodeTemplate = nodes.nodeTemplates[node.data.type];
|
||||
|
||||
if (!nodeTemplate) {
|
||||
// Node type not found
|
||||
reasons.push('Missing node template');
|
||||
return;
|
||||
}
|
||||
|
||||
const connectedEdges = getConnectedEdges([node], nodes.edges);
|
||||
|
||||
forEach(node.data.inputs, (field) => {
|
||||
const fieldTemplate = nodeTemplate.inputs[field.name];
|
||||
const hasConnection = connectedEdges.some(
|
||||
(edge) =>
|
||||
edge.target === node.id && edge.targetHandle === field.name
|
||||
);
|
||||
|
||||
if (!fieldTemplate) {
|
||||
reasons.push('Missing field template');
|
||||
return;
|
||||
}
|
||||
|
||||
if (fieldTemplate.required && !field.value && !hasConnection) {
|
||||
reasons.push(
|
||||
`${node.data.label || nodeTemplate.title} -> ${
|
||||
field.label || fieldTemplate.title
|
||||
} missing input`
|
||||
);
|
||||
return;
|
||||
}
|
||||
});
|
||||
});
|
||||
} else {
|
||||
if (!model) {
|
||||
reasons.push('No model selected');
|
||||
}
|
||||
|
||||
if (state.controlNet.isEnabled) {
|
||||
map(state.controlNet.controlNets).forEach((controlNet, i) => {
|
||||
if (!controlNet.isEnabled) {
|
||||
return;
|
||||
}
|
||||
if (!controlNet.model) {
|
||||
reasons.push(`ControlNet ${i + 1} has no model selected.`);
|
||||
}
|
||||
|
||||
if (
|
||||
!controlNet.controlImage ||
|
||||
(!controlNet.processedControlImage &&
|
||||
controlNet.processorType !== 'none')
|
||||
) {
|
||||
reasons.push(`ControlNet ${i + 1} has no control image`);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return { isReady: !reasons.length, isProcessing, reasons };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
export const useIsReadyToInvoke = () => {
|
||||
const { isReady } = useAppSelector(readinessSelector);
|
||||
return isReady;
|
||||
const { isReady, isProcessing, reasons } = useAppSelector(selector);
|
||||
return { isReady, isProcessing, reasons };
|
||||
};
|
||||
|
@ -11,8 +11,14 @@ export default function useResolution():
|
||||
const tabletResolutions = ['md', 'lg'];
|
||||
const desktopResolutions = ['xl', '2xl'];
|
||||
|
||||
if (mobileResolutions.includes(breakpointValue)) return 'mobile';
|
||||
if (tabletResolutions.includes(breakpointValue)) return 'tablet';
|
||||
if (desktopResolutions.includes(breakpointValue)) return 'desktop';
|
||||
if (mobileResolutions.includes(breakpointValue)) {
|
||||
return 'mobile';
|
||||
}
|
||||
if (tabletResolutions.includes(breakpointValue)) {
|
||||
return 'tablet';
|
||||
}
|
||||
if (desktopResolutions.includes(breakpointValue)) {
|
||||
return 'desktop';
|
||||
}
|
||||
return 'unknown';
|
||||
}
|
||||
|
@ -0,0 +1,2 @@
|
||||
export const colorTokenToCssVar = (colorToken: string) =>
|
||||
`var(--invokeai-colors-${colorToken.split('.').join('-')}`;
|
@ -6,7 +6,11 @@ export const dateComparator = (a: string, b: string) => {
|
||||
const dateB = new Date(b);
|
||||
|
||||
// sort in ascending order
|
||||
if (dateA > dateB) return 1;
|
||||
if (dateA < dateB) return -1;
|
||||
if (dateA > dateB) {
|
||||
return 1;
|
||||
}
|
||||
if (dateA < dateB) {
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
};
|
||||
|
@ -5,7 +5,9 @@ type Base64AndCaption = {
|
||||
|
||||
const openBase64ImageInTab = (images: Base64AndCaption[]) => {
|
||||
const w = window.open('');
|
||||
if (!w) return;
|
||||
if (!w) {
|
||||
return;
|
||||
}
|
||||
|
||||
images.forEach((i) => {
|
||||
const image = new Image();
|
||||
|
@ -5,6 +5,7 @@ import { clearCanvasHistory } from 'features/canvas/store/canvasSlice';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaTrash } from 'react-icons/fa';
|
||||
import { isStagingSelector } from '../store/canvasSelectors';
|
||||
import { memo } from 'react';
|
||||
|
||||
const ClearCanvasHistoryButtonModal = () => {
|
||||
const isStaging = useAppSelector(isStagingSelector);
|
||||
@ -28,4 +29,4 @@ const ClearCanvasHistoryButtonModal = () => {
|
||||
</IAIAlertDialog>
|
||||
);
|
||||
};
|
||||
export default ClearCanvasHistoryButtonModal;
|
||||
export default memo(ClearCanvasHistoryButtonModal);
|
||||
|
@ -1,6 +1,6 @@
|
||||
import { Box, chakra, Flex } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import {
|
||||
canvasSelector,
|
||||
@ -9,7 +9,7 @@ import {
|
||||
import Konva from 'konva';
|
||||
import { KonvaEventObject } from 'konva/lib/Node';
|
||||
import { Vector2d } from 'konva/lib/types';
|
||||
import { useCallback, useRef } from 'react';
|
||||
import { memo, useCallback, useEffect, useRef } from 'react';
|
||||
import { Layer, Stage } from 'react-konva';
|
||||
import useCanvasDragMove from '../hooks/useCanvasDragMove';
|
||||
import useCanvasHotkeys from '../hooks/useCanvasHotkeys';
|
||||
@ -18,6 +18,7 @@ import useCanvasMouseMove from '../hooks/useCanvasMouseMove';
|
||||
import useCanvasMouseOut from '../hooks/useCanvasMouseOut';
|
||||
import useCanvasMouseUp from '../hooks/useCanvasMouseUp';
|
||||
import useCanvasWheel from '../hooks/useCanvasZoom';
|
||||
import { canvasResized } from '../store/canvasSlice';
|
||||
import {
|
||||
setCanvasBaseLayer,
|
||||
setCanvasStage,
|
||||
@ -106,7 +107,8 @@ const IAICanvas = () => {
|
||||
shouldAntialias,
|
||||
} = useAppSelector(selector);
|
||||
useCanvasHotkeys();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const stageRef = useRef<Konva.Stage | null>(null);
|
||||
const canvasBaseLayerRef = useRef<Konva.Layer | null>(null);
|
||||
|
||||
@ -137,8 +139,30 @@ const IAICanvas = () => {
|
||||
const { handleDragStart, handleDragMove, handleDragEnd } =
|
||||
useCanvasDragMove();
|
||||
|
||||
useEffect(() => {
|
||||
if (!containerRef.current) {
|
||||
return;
|
||||
}
|
||||
const resizeObserver = new ResizeObserver((entries) => {
|
||||
for (const entry of entries) {
|
||||
if (entry.contentBoxSize) {
|
||||
const { width, height } = entry.contentRect;
|
||||
dispatch(canvasResized({ width, height }));
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
resizeObserver.observe(containerRef.current);
|
||||
|
||||
return () => {
|
||||
resizeObserver.disconnect();
|
||||
};
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
id="canvas-container"
|
||||
ref={containerRef}
|
||||
sx={{
|
||||
position: 'relative',
|
||||
height: '100%',
|
||||
@ -146,13 +170,18 @@ const IAICanvas = () => {
|
||||
borderRadius: 'base',
|
||||
}}
|
||||
>
|
||||
<Box sx={{ position: 'relative' }}>
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
// top: 0,
|
||||
// insetInlineStart: 0,
|
||||
}}
|
||||
>
|
||||
<ChakraStage
|
||||
tabIndex={-1}
|
||||
ref={canvasStageRefCallback}
|
||||
sx={{
|
||||
outline: 'none',
|
||||
// boxShadow: '0px 0px 0px 1px var(--border-color-light)',
|
||||
overflow: 'hidden',
|
||||
cursor: stageCursor ? stageCursor : undefined,
|
||||
canvas: {
|
||||
@ -213,11 +242,11 @@ const IAICanvas = () => {
|
||||
/>
|
||||
</Layer>
|
||||
</ChakraStage>
|
||||
<IAICanvasStatusText />
|
||||
<IAICanvasStagingAreaToolbar />
|
||||
</Box>
|
||||
<IAICanvasStatusText />
|
||||
<IAICanvasStagingAreaToolbar />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default IAICanvas;
|
||||
export default memo(IAICanvas);
|
||||
|
@ -4,6 +4,7 @@ import { isEqual } from 'lodash-es';
|
||||
|
||||
import { Group, Rect } from 'react-konva';
|
||||
import { canvasSelector } from '../store/canvasSelectors';
|
||||
import { memo } from 'react';
|
||||
|
||||
const selector = createSelector(
|
||||
canvasSelector,
|
||||
@ -67,4 +68,4 @@ const IAICanvasBoundingBoxOverlay = () => {
|
||||
);
|
||||
};
|
||||
|
||||
export default IAICanvasBoundingBoxOverlay;
|
||||
export default memo(IAICanvasBoundingBoxOverlay);
|
||||
|
@ -6,7 +6,7 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||
import { isEqual, range } from 'lodash-es';
|
||||
|
||||
import { ReactNode, useCallback, useLayoutEffect, useState } from 'react';
|
||||
import { ReactNode, memo, useCallback, useLayoutEffect, useState } from 'react';
|
||||
import { Group, Line as KonvaLine } from 'react-konva';
|
||||
|
||||
const selector = createSelector(
|
||||
@ -117,4 +117,4 @@ const IAICanvasGrid = () => {
|
||||
return <Group>{gridLines}</Group>;
|
||||
};
|
||||
|
||||
export default IAICanvasGrid;
|
||||
export default memo(IAICanvasGrid);
|
||||
|
@ -4,6 +4,7 @@ import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import useImage from 'use-image';
|
||||
import { CanvasImage } from '../store/canvasTypes';
|
||||
import { $authToken } from 'services/api/client';
|
||||
import { memo } from 'react';
|
||||
|
||||
type IAICanvasImageProps = {
|
||||
canvasImage: CanvasImage;
|
||||
@ -25,4 +26,4 @@ const IAICanvasImage = (props: IAICanvasImageProps) => {
|
||||
return <Image x={x} y={y} image={image} listening={false} />;
|
||||
};
|
||||
|
||||
export default IAICanvasImage;
|
||||
export default memo(IAICanvasImage);
|
||||
|
@ -4,7 +4,7 @@ import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { ImageConfig } from 'konva/lib/shapes/Image';
|
||||
import { isEqual } from 'lodash-es';
|
||||
|
||||
import { useEffect, useState } from 'react';
|
||||
import { memo, useEffect, useState } from 'react';
|
||||
import { Image as KonvaImage } from 'react-konva';
|
||||
import { canvasSelector } from '../store/canvasSelectors';
|
||||
|
||||
@ -66,4 +66,4 @@ const IAICanvasIntermediateImage = (props: Props) => {
|
||||
) : null;
|
||||
};
|
||||
|
||||
export default IAICanvasIntermediateImage;
|
||||
export default memo(IAICanvasIntermediateImage);
|
||||
|
@ -7,7 +7,7 @@ import { Rect } from 'react-konva';
|
||||
import { rgbaColorToString } from 'features/canvas/util/colorToString';
|
||||
import Konva from 'konva';
|
||||
import { isNumber } from 'lodash-es';
|
||||
import { useCallback, useEffect, useRef, useState } from 'react';
|
||||
import { memo, useCallback, useEffect, useRef, useState } from 'react';
|
||||
|
||||
export const canvasMaskCompositerSelector = createSelector(
|
||||
canvasSelector,
|
||||
@ -125,7 +125,9 @@ const IAICanvasMaskCompositer = (props: IAICanvasMaskCompositerProps) => {
|
||||
}, [offset]);
|
||||
|
||||
useEffect(() => {
|
||||
if (fillPatternImage) return;
|
||||
if (fillPatternImage) {
|
||||
return;
|
||||
}
|
||||
const image = new Image();
|
||||
|
||||
image.onload = () => {
|
||||
@ -135,7 +137,9 @@ const IAICanvasMaskCompositer = (props: IAICanvasMaskCompositerProps) => {
|
||||
}, [fillPatternImage, maskColorString]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!fillPatternImage) return;
|
||||
if (!fillPatternImage) {
|
||||
return;
|
||||
}
|
||||
fillPatternImage.src = getColoredSVG(maskColorString);
|
||||
}, [fillPatternImage, maskColorString]);
|
||||
|
||||
@ -151,8 +155,9 @@ const IAICanvasMaskCompositer = (props: IAICanvasMaskCompositerProps) => {
|
||||
!isNumber(stageScale) ||
|
||||
!isNumber(stageDimensions.width) ||
|
||||
!isNumber(stageDimensions.height)
|
||||
)
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Rect
|
||||
@ -172,4 +177,4 @@ const IAICanvasMaskCompositer = (props: IAICanvasMaskCompositerProps) => {
|
||||
);
|
||||
};
|
||||
|
||||
export default IAICanvasMaskCompositer;
|
||||
export default memo(IAICanvasMaskCompositer);
|
||||
|
@ -6,6 +6,7 @@ import { isEqual } from 'lodash-es';
|
||||
|
||||
import { Group, Line } from 'react-konva';
|
||||
import { isCanvasMaskLine } from '../store/canvasTypes';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const canvasLinesSelector = createSelector(
|
||||
[canvasSelector],
|
||||
@ -52,4 +53,4 @@ const IAICanvasLines = (props: InpaintingCanvasLinesProps) => {
|
||||
);
|
||||
};
|
||||
|
||||
export default IAICanvasLines;
|
||||
export default memo(IAICanvasLines);
|
||||
|
@ -12,6 +12,7 @@ import {
|
||||
isCanvasFillRect,
|
||||
} from '../store/canvasTypes';
|
||||
import IAICanvasImage from './IAICanvasImage';
|
||||
import { memo } from 'react';
|
||||
|
||||
const selector = createSelector(
|
||||
[canvasSelector],
|
||||
@ -33,7 +34,9 @@ const selector = createSelector(
|
||||
const IAICanvasObjectRenderer = () => {
|
||||
const { objects } = useAppSelector(selector);
|
||||
|
||||
if (!objects) return null;
|
||||
if (!objects) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Group name="outpainting-objects" listening={false}>
|
||||
@ -101,4 +104,4 @@ const IAICanvasObjectRenderer = () => {
|
||||
);
|
||||
};
|
||||
|
||||
export default IAICanvasObjectRenderer;
|
||||
export default memo(IAICanvasObjectRenderer);
|
||||
|
@ -1,89 +0,0 @@
|
||||
import { Flex, Spinner } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
canvasSelector,
|
||||
initialCanvasImageSelector,
|
||||
} from 'features/canvas/store/canvasSelectors';
|
||||
import {
|
||||
resizeAndScaleCanvas,
|
||||
resizeCanvas,
|
||||
setCanvasContainerDimensions,
|
||||
setDoesCanvasNeedScaling,
|
||||
} from 'features/canvas/store/canvasSlice';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { useLayoutEffect, useRef } from 'react';
|
||||
|
||||
const canvasResizerSelector = createSelector(
|
||||
canvasSelector,
|
||||
initialCanvasImageSelector,
|
||||
activeTabNameSelector,
|
||||
(canvas, initialCanvasImage, activeTabName) => {
|
||||
const { doesCanvasNeedScaling, isCanvasInitialized } = canvas;
|
||||
return {
|
||||
doesCanvasNeedScaling,
|
||||
activeTabName,
|
||||
initialCanvasImage,
|
||||
isCanvasInitialized,
|
||||
};
|
||||
}
|
||||
);
|
||||
|
||||
const IAICanvasResizer = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const {
|
||||
doesCanvasNeedScaling,
|
||||
activeTabName,
|
||||
initialCanvasImage,
|
||||
isCanvasInitialized,
|
||||
} = useAppSelector(canvasResizerSelector);
|
||||
|
||||
const ref = useRef<HTMLDivElement>(null);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
window.setTimeout(() => {
|
||||
if (!ref.current) return;
|
||||
|
||||
const { clientWidth, clientHeight } = ref.current;
|
||||
|
||||
dispatch(
|
||||
setCanvasContainerDimensions({
|
||||
width: clientWidth,
|
||||
height: clientHeight,
|
||||
})
|
||||
);
|
||||
|
||||
if (!isCanvasInitialized) {
|
||||
dispatch(resizeAndScaleCanvas());
|
||||
} else {
|
||||
dispatch(resizeCanvas());
|
||||
}
|
||||
|
||||
dispatch(setDoesCanvasNeedScaling(false));
|
||||
}, 0);
|
||||
}, [
|
||||
dispatch,
|
||||
initialCanvasImage,
|
||||
doesCanvasNeedScaling,
|
||||
activeTabName,
|
||||
isCanvasInitialized,
|
||||
]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
ref={ref}
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
gap: 4,
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
}}
|
||||
>
|
||||
<Spinner thickness="2px" size="xl" />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default IAICanvasResizer;
|
@ -6,6 +6,7 @@ import { isEqual } from 'lodash-es';
|
||||
|
||||
import { Group, Rect } from 'react-konva';
|
||||
import IAICanvasImage from './IAICanvasImage';
|
||||
import { memo } from 'react';
|
||||
|
||||
const selector = createSelector(
|
||||
[canvasSelector],
|
||||
@ -88,4 +89,4 @@ const IAICanvasStagingArea = (props: Props) => {
|
||||
);
|
||||
};
|
||||
|
||||
export default IAICanvasStagingArea;
|
||||
export default memo(IAICanvasStagingArea);
|
||||
|
@ -13,7 +13,7 @@ import {
|
||||
} from 'features/canvas/store/canvasSlice';
|
||||
import { isEqual } from 'lodash-es';
|
||||
|
||||
import { useCallback } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
@ -129,7 +129,9 @@ const IAICanvasStagingAreaToolbar = () => {
|
||||
currentStagingAreaImage?.imageName ?? skipToken
|
||||
);
|
||||
|
||||
if (!currentStagingAreaImage) return null;
|
||||
if (!currentStagingAreaImage) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex
|
||||
@ -138,11 +140,10 @@ const IAICanvasStagingAreaToolbar = () => {
|
||||
w="100%"
|
||||
align="center"
|
||||
justify="center"
|
||||
filter="drop-shadow(0 0.5rem 1rem rgba(0,0,0))"
|
||||
onMouseOver={handleMouseOver}
|
||||
onMouseOut={handleMouseOut}
|
||||
>
|
||||
<ButtonGroup isAttached>
|
||||
<ButtonGroup isAttached borderRadius="base" shadow="dark-lg">
|
||||
<IAIIconButton
|
||||
tooltip={`${t('unifiedCanvas.previous')} (Left)`}
|
||||
aria-label={`${t('unifiedCanvas.previous')} (Left)`}
|
||||
@ -207,4 +208,4 @@ const IAICanvasStagingAreaToolbar = () => {
|
||||
);
|
||||
};
|
||||
|
||||
export default IAICanvasStagingAreaToolbar;
|
||||
export default memo(IAICanvasStagingAreaToolbar);
|
||||
|
@ -7,6 +7,7 @@ import { isEqual } from 'lodash-es';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import roundToHundreth from '../util/roundToHundreth';
|
||||
import IAICanvasStatusTextCursorPos from './IAICanvasStatusText/IAICanvasStatusTextCursorPos';
|
||||
import { memo } from 'react';
|
||||
|
||||
const warningColor = 'var(--invokeai-colors-warning-500)';
|
||||
|
||||
@ -162,4 +163,4 @@ const IAICanvasStatusText = () => {
|
||||
);
|
||||
};
|
||||
|
||||
export default IAICanvasStatusText;
|
||||
export default memo(IAICanvasStatusText);
|
||||
|
@ -10,6 +10,7 @@ import {
|
||||
COLOR_PICKER_SIZE,
|
||||
COLOR_PICKER_STROKE_RADIUS,
|
||||
} from '../util/constants';
|
||||
import { memo } from 'react';
|
||||
|
||||
const canvasBrushPreviewSelector = createSelector(
|
||||
canvasSelector,
|
||||
@ -134,7 +135,9 @@ const IAICanvasToolPreview = (props: GroupConfig) => {
|
||||
clip,
|
||||
} = useAppSelector(canvasBrushPreviewSelector);
|
||||
|
||||
if (!shouldDrawBrushPreview) return null;
|
||||
if (!shouldDrawBrushPreview) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Group listening={false} {...clip} {...rest}>
|
||||
@ -206,4 +209,4 @@ const IAICanvasToolPreview = (props: GroupConfig) => {
|
||||
);
|
||||
};
|
||||
|
||||
export default IAICanvasToolPreview;
|
||||
export default memo(IAICanvasToolPreview);
|
||||
|
@ -19,7 +19,7 @@ import { KonvaEventObject } from 'konva/lib/Node';
|
||||
import { Vector2d } from 'konva/lib/types';
|
||||
import { isEqual } from 'lodash-es';
|
||||
|
||||
import { useCallback, useEffect, useRef, useState } from 'react';
|
||||
import { memo, useCallback, useEffect, useRef, useState } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { Group, Rect, Transformer } from 'react-konva';
|
||||
|
||||
@ -85,7 +85,9 @@ const IAICanvasBoundingBox = (props: IAICanvasBoundingBoxPreviewProps) => {
|
||||
useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
if (!transformerRef.current || !shapeRef.current) return;
|
||||
if (!transformerRef.current || !shapeRef.current) {
|
||||
return;
|
||||
}
|
||||
transformerRef.current.nodes([shapeRef.current]);
|
||||
transformerRef.current.getLayer()?.batchDraw();
|
||||
}, []);
|
||||
@ -133,7 +135,9 @@ const IAICanvasBoundingBox = (props: IAICanvasBoundingBoxPreviewProps) => {
|
||||
* not its width and height. We need to un-scale the width and height before
|
||||
* setting the values.
|
||||
*/
|
||||
if (!shapeRef.current) return;
|
||||
if (!shapeRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
const rect = shapeRef.current;
|
||||
|
||||
@ -313,4 +317,4 @@ const IAICanvasBoundingBox = (props: IAICanvasBoundingBoxPreviewProps) => {
|
||||
);
|
||||
};
|
||||
|
||||
export default IAICanvasBoundingBox;
|
||||
export default memo(IAICanvasBoundingBox);
|
||||
|
@ -20,6 +20,7 @@ import {
|
||||
} from 'features/canvas/store/canvasSlice';
|
||||
import { rgbaColorToString } from 'features/canvas/util/colorToString';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { memo } from 'react';
|
||||
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -150,4 +151,4 @@ const IAICanvasMaskOptions = () => {
|
||||
);
|
||||
};
|
||||
|
||||
export default IAICanvasMaskOptions;
|
||||
export default memo(IAICanvasMaskOptions);
|
||||
|
@ -18,7 +18,7 @@ import {
|
||||
} from 'features/canvas/store/canvasSlice';
|
||||
import { isEqual } from 'lodash-es';
|
||||
|
||||
import { ChangeEvent } from 'react';
|
||||
import { ChangeEvent, memo } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaWrench } from 'react-icons/fa';
|
||||
@ -163,4 +163,4 @@ const IAICanvasSettingsButtonPopover = () => {
|
||||
);
|
||||
};
|
||||
|
||||
export default IAICanvasSettingsButtonPopover;
|
||||
export default memo(IAICanvasSettingsButtonPopover);
|
||||
|
@ -18,6 +18,7 @@ import {
|
||||
} from 'features/canvas/store/canvasSlice';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { clamp, isEqual } from 'lodash-es';
|
||||
import { memo } from 'react';
|
||||
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -252,4 +253,4 @@ const IAICanvasToolChooserOptions = () => {
|
||||
);
|
||||
};
|
||||
|
||||
export default IAICanvasToolChooserOptions;
|
||||
export default memo(IAICanvasToolChooserOptions);
|
||||
|
@ -18,7 +18,6 @@ import {
|
||||
import {
|
||||
resetCanvas,
|
||||
resetCanvasView,
|
||||
resizeAndScaleCanvas,
|
||||
setIsMaskEnabled,
|
||||
setLayer,
|
||||
setTool,
|
||||
@ -48,6 +47,7 @@ import IAICanvasRedoButton from './IAICanvasRedoButton';
|
||||
import IAICanvasSettingsButtonPopover from './IAICanvasSettingsButtonPopover';
|
||||
import IAICanvasToolChooserOptions from './IAICanvasToolChooserOptions';
|
||||
import IAICanvasUndoButton from './IAICanvasUndoButton';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const selector = createSelector(
|
||||
[systemSelector, canvasSelector, isStagingSelector],
|
||||
@ -166,7 +166,9 @@ const IAICanvasToolbar = () => {
|
||||
|
||||
const handleResetCanvasView = (shouldScaleTo1 = false) => {
|
||||
const canvasBaseLayer = getCanvasBaseLayer();
|
||||
if (!canvasBaseLayer) return;
|
||||
if (!canvasBaseLayer) {
|
||||
return;
|
||||
}
|
||||
const clientRect = canvasBaseLayer.getClientRect({
|
||||
skipTransform: true,
|
||||
});
|
||||
@ -180,7 +182,6 @@ const IAICanvasToolbar = () => {
|
||||
|
||||
const handleResetCanvas = () => {
|
||||
dispatch(resetCanvas());
|
||||
dispatch(resizeAndScaleCanvas());
|
||||
};
|
||||
|
||||
const handleMergeVisible = () => {
|
||||
@ -309,4 +310,4 @@ const IAICanvasToolbar = () => {
|
||||
);
|
||||
};
|
||||
|
||||
export default IAICanvasToolbar;
|
||||
export default memo(IAICanvasToolbar);
|
||||
|
@ -32,13 +32,17 @@ const useCanvasDrag = () => {
|
||||
|
||||
return {
|
||||
handleDragStart: useCallback(() => {
|
||||
if (!((tool === 'move' || isStaging) && !isMovingBoundingBox)) return;
|
||||
if (!((tool === 'move' || isStaging) && !isMovingBoundingBox)) {
|
||||
return;
|
||||
}
|
||||
dispatch(setIsMovingStage(true));
|
||||
}, [dispatch, isMovingBoundingBox, isStaging, tool]),
|
||||
|
||||
handleDragMove: useCallback(
|
||||
(e: KonvaEventObject<MouseEvent>) => {
|
||||
if (!((tool === 'move' || isStaging) && !isMovingBoundingBox)) return;
|
||||
if (!((tool === 'move' || isStaging) && !isMovingBoundingBox)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const newCoordinates = { x: e.target.x(), y: e.target.y() };
|
||||
|
||||
@ -48,7 +52,9 @@ const useCanvasDrag = () => {
|
||||
),
|
||||
|
||||
handleDragEnd: useCallback(() => {
|
||||
if (!((tool === 'move' || isStaging) && !isMovingBoundingBox)) return;
|
||||
if (!((tool === 'move' || isStaging) && !isMovingBoundingBox)) {
|
||||
return;
|
||||
}
|
||||
dispatch(setIsMovingStage(false));
|
||||
}, [dispatch, isMovingBoundingBox, isStaging, tool]),
|
||||
};
|
||||
|
@ -134,7 +134,9 @@ const useInpaintingCanvasHotkeys = () => {
|
||||
useHotkeys(
|
||||
['space'],
|
||||
(e: KeyboardEvent) => {
|
||||
if (e.repeat) return;
|
||||
if (e.repeat) {
|
||||
return;
|
||||
}
|
||||
|
||||
canvasStage?.container().focus();
|
||||
|
||||
|
@ -38,7 +38,9 @@ const useCanvasMouseDown = (stageRef: MutableRefObject<Konva.Stage | null>) => {
|
||||
|
||||
return useCallback(
|
||||
(e: KonvaEventObject<MouseEvent | TouchEvent>) => {
|
||||
if (!stageRef.current) return;
|
||||
if (!stageRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
stageRef.current.container().focus();
|
||||
|
||||
@ -54,7 +56,9 @@ const useCanvasMouseDown = (stageRef: MutableRefObject<Konva.Stage | null>) => {
|
||||
|
||||
const scaledCursorPosition = getScaledCursorPosition(stageRef.current);
|
||||
|
||||
if (!scaledCursorPosition) return;
|
||||
if (!scaledCursorPosition) {
|
||||
return;
|
||||
}
|
||||
|
||||
e.evt.preventDefault();
|
||||
|
||||
|
@ -41,11 +41,15 @@ const useCanvasMouseMove = (
|
||||
const { updateColorUnderCursor } = useColorPicker();
|
||||
|
||||
return useCallback(() => {
|
||||
if (!stageRef.current) return;
|
||||
if (!stageRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
const scaledCursorPosition = getScaledCursorPosition(stageRef.current);
|
||||
|
||||
if (!scaledCursorPosition) return;
|
||||
if (!scaledCursorPosition) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(setCursorPosition(scaledCursorPosition));
|
||||
|
||||
@ -56,7 +60,9 @@ const useCanvasMouseMove = (
|
||||
return;
|
||||
}
|
||||
|
||||
if (!isDrawing || tool === 'move' || isStaging) return;
|
||||
if (!isDrawing || tool === 'move' || isStaging) {
|
||||
return;
|
||||
}
|
||||
|
||||
didMouseMoveRef.current = true;
|
||||
dispatch(
|
||||
|
@ -47,7 +47,9 @@ const useCanvasMouseUp = (
|
||||
if (!didMouseMoveRef.current && isDrawing && stageRef.current) {
|
||||
const scaledCursorPosition = getScaledCursorPosition(stageRef.current);
|
||||
|
||||
if (!scaledCursorPosition) return;
|
||||
if (!scaledCursorPosition) {
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extend the current line.
|
||||
|
@ -35,13 +35,17 @@ const useCanvasWheel = (stageRef: MutableRefObject<Konva.Stage | null>) => {
|
||||
return useCallback(
|
||||
(e: KonvaEventObject<WheelEvent>) => {
|
||||
// stop default scrolling
|
||||
if (!stageRef.current || isMoveStageKeyHeld) return;
|
||||
if (!stageRef.current || isMoveStageKeyHeld) {
|
||||
return;
|
||||
}
|
||||
|
||||
e.evt.preventDefault();
|
||||
|
||||
const cursorPos = stageRef.current.getPointerPosition();
|
||||
|
||||
if (!cursorPos) return;
|
||||
if (!cursorPos) {
|
||||
return;
|
||||
}
|
||||
|
||||
const mousePointTo = {
|
||||
x: (cursorPos.x - stageRef.current.x()) / stageScale,
|
||||
|
@ -16,11 +16,15 @@ const useColorPicker = () => {
|
||||
|
||||
return {
|
||||
updateColorUnderCursor: () => {
|
||||
if (!stage || !canvasBaseLayer) return;
|
||||
if (!stage || !canvasBaseLayer) {
|
||||
return;
|
||||
}
|
||||
|
||||
const position = stage.getPointerPosition();
|
||||
|
||||
if (!position) return;
|
||||
if (!position) {
|
||||
return;
|
||||
}
|
||||
|
||||
const pixelRatio = Konva.pixelRatio;
|
||||
|
||||
|
@ -3,8 +3,4 @@ import { CanvasState } from './canvasTypes';
|
||||
/**
|
||||
* Canvas slice persist denylist
|
||||
*/
|
||||
export const canvasPersistDenylist: (keyof CanvasState)[] = [
|
||||
'cursorPosition',
|
||||
'isCanvasInitialized',
|
||||
'doesCanvasNeedScaling',
|
||||
];
|
||||
export const canvasPersistDenylist: (keyof CanvasState)[] = ['cursorPosition'];
|
||||
|
@ -5,10 +5,6 @@ import {
|
||||
roundToMultiple,
|
||||
} from 'common/util/roundDownToMultiple';
|
||||
import { setAspectRatio } from 'features/parameters/store/generationSlice';
|
||||
import {
|
||||
setActiveTab,
|
||||
setShouldUseCanvasBetaLayout,
|
||||
} from 'features/ui/store/uiSlice';
|
||||
import { IRect, Vector2d } from 'konva/lib/types';
|
||||
import { clamp, cloneDeep } from 'lodash-es';
|
||||
import { RgbaColor } from 'react-colorful';
|
||||
@ -50,12 +46,9 @@ export const initialCanvasState: CanvasState = {
|
||||
boundingBoxScaleMethod: 'none',
|
||||
brushColor: { r: 90, g: 90, b: 255, a: 1 },
|
||||
brushSize: 50,
|
||||
canvasContainerDimensions: { width: 0, height: 0 },
|
||||
colorPickerColor: { r: 90, g: 90, b: 255, a: 1 },
|
||||
cursorPosition: null,
|
||||
doesCanvasNeedScaling: false,
|
||||
futureLayerStates: [],
|
||||
isCanvasInitialized: false,
|
||||
isDrawing: false,
|
||||
isMaskEnabled: true,
|
||||
isMouseOverBoundingBox: false,
|
||||
@ -208,7 +201,6 @@ export const canvasSlice = createSlice({
|
||||
};
|
||||
state.futureLayerStates = [];
|
||||
|
||||
state.isCanvasInitialized = false;
|
||||
const newScale = calculateScale(
|
||||
stageDimensions.width,
|
||||
stageDimensions.height,
|
||||
@ -228,7 +220,6 @@ export const canvasSlice = createSlice({
|
||||
);
|
||||
state.stageScale = newScale;
|
||||
state.stageCoordinates = newCoordinates;
|
||||
state.doesCanvasNeedScaling = true;
|
||||
},
|
||||
setBoundingBoxDimensions: (state, action: PayloadAction<Dimensions>) => {
|
||||
const newDimensions = roundDimensionsTo64(action.payload);
|
||||
@ -258,9 +249,6 @@ export const canvasSlice = createSlice({
|
||||
setBoundingBoxPreviewFill: (state, action: PayloadAction<RgbaColor>) => {
|
||||
state.boundingBoxPreviewFill = action.payload;
|
||||
},
|
||||
setDoesCanvasNeedScaling: (state, action: PayloadAction<boolean>) => {
|
||||
state.doesCanvasNeedScaling = action.payload;
|
||||
},
|
||||
setStageScale: (state, action: PayloadAction<number>) => {
|
||||
state.stageScale = action.payload;
|
||||
},
|
||||
@ -397,7 +385,9 @@ export const canvasSlice = createSlice({
|
||||
const { tool, layer, brushColor, brushSize, shouldRestrictStrokesToBox } =
|
||||
state;
|
||||
|
||||
if (tool === 'move' || tool === 'colorPicker') return;
|
||||
if (tool === 'move' || tool === 'colorPicker') {
|
||||
return;
|
||||
}
|
||||
|
||||
const newStrokeWidth = brushSize / 2;
|
||||
|
||||
@ -434,14 +424,18 @@ export const canvasSlice = createSlice({
|
||||
addPointToCurrentLine: (state, action: PayloadAction<number[]>) => {
|
||||
const lastLine = state.layerState.objects.findLast(isCanvasAnyLine);
|
||||
|
||||
if (!lastLine) return;
|
||||
if (!lastLine) {
|
||||
return;
|
||||
}
|
||||
|
||||
lastLine.points.push(...action.payload);
|
||||
},
|
||||
undo: (state) => {
|
||||
const targetState = state.pastLayerStates.pop();
|
||||
|
||||
if (!targetState) return;
|
||||
if (!targetState) {
|
||||
return;
|
||||
}
|
||||
|
||||
state.futureLayerStates.unshift(cloneDeep(state.layerState));
|
||||
|
||||
@ -454,7 +448,9 @@ export const canvasSlice = createSlice({
|
||||
redo: (state) => {
|
||||
const targetState = state.futureLayerStates.shift();
|
||||
|
||||
if (!targetState) return;
|
||||
if (!targetState) {
|
||||
return;
|
||||
}
|
||||
|
||||
state.pastLayerStates.push(cloneDeep(state.layerState));
|
||||
|
||||
@ -485,97 +481,14 @@ export const canvasSlice = createSlice({
|
||||
state.layerState = initialLayerState;
|
||||
state.futureLayerStates = [];
|
||||
},
|
||||
setCanvasContainerDimensions: (
|
||||
canvasResized: (
|
||||
state,
|
||||
action: PayloadAction<Dimensions>
|
||||
action: PayloadAction<{ width: number; height: number }>
|
||||
) => {
|
||||
state.canvasContainerDimensions = action.payload;
|
||||
},
|
||||
resizeAndScaleCanvas: (state) => {
|
||||
const { width: containerWidth, height: containerHeight } =
|
||||
state.canvasContainerDimensions;
|
||||
|
||||
const initialCanvasImage =
|
||||
state.layerState.objects.find(isCanvasBaseImage);
|
||||
|
||||
const { width, height } = action.payload;
|
||||
const newStageDimensions = {
|
||||
width: Math.floor(containerWidth),
|
||||
height: Math.floor(containerHeight),
|
||||
};
|
||||
|
||||
if (!initialCanvasImage) {
|
||||
const newScale = calculateScale(
|
||||
newStageDimensions.width,
|
||||
newStageDimensions.height,
|
||||
512,
|
||||
512,
|
||||
STAGE_PADDING_PERCENTAGE
|
||||
);
|
||||
|
||||
const newCoordinates = calculateCoordinates(
|
||||
newStageDimensions.width,
|
||||
newStageDimensions.height,
|
||||
0,
|
||||
0,
|
||||
512,
|
||||
512,
|
||||
newScale
|
||||
);
|
||||
|
||||
const newBoundingBoxDimensions = { width: 512, height: 512 };
|
||||
|
||||
state.stageScale = newScale;
|
||||
state.stageCoordinates = newCoordinates;
|
||||
state.stageDimensions = newStageDimensions;
|
||||
state.boundingBoxCoordinates = { x: 0, y: 0 };
|
||||
state.boundingBoxDimensions = newBoundingBoxDimensions;
|
||||
|
||||
if (state.boundingBoxScaleMethod === 'auto') {
|
||||
const scaledDimensions = getScaledBoundingBoxDimensions(
|
||||
newBoundingBoxDimensions
|
||||
);
|
||||
state.scaledBoundingBoxDimensions = scaledDimensions;
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const { width: imageWidth, height: imageHeight } = initialCanvasImage;
|
||||
|
||||
const padding = 0.95;
|
||||
|
||||
const newScale = calculateScale(
|
||||
containerWidth,
|
||||
containerHeight,
|
||||
imageWidth,
|
||||
imageHeight,
|
||||
padding
|
||||
);
|
||||
|
||||
const newCoordinates = calculateCoordinates(
|
||||
newStageDimensions.width,
|
||||
newStageDimensions.height,
|
||||
0,
|
||||
0,
|
||||
imageWidth,
|
||||
imageHeight,
|
||||
newScale
|
||||
);
|
||||
|
||||
state.minimumStageScale = newScale;
|
||||
state.stageScale = newScale;
|
||||
state.stageCoordinates = floorCoordinates(newCoordinates);
|
||||
state.stageDimensions = newStageDimensions;
|
||||
|
||||
state.isCanvasInitialized = true;
|
||||
},
|
||||
resizeCanvas: (state) => {
|
||||
const { width: containerWidth, height: containerHeight } =
|
||||
state.canvasContainerDimensions;
|
||||
|
||||
const newStageDimensions = {
|
||||
width: Math.floor(containerWidth),
|
||||
height: Math.floor(containerHeight),
|
||||
width: Math.floor(width),
|
||||
height: Math.floor(height),
|
||||
};
|
||||
|
||||
state.stageDimensions = newStageDimensions;
|
||||
@ -868,14 +781,6 @@ export const canvasSlice = createSlice({
|
||||
state.layerState.stagingArea = initialLayerState.stagingArea;
|
||||
}
|
||||
});
|
||||
|
||||
builder.addCase(setShouldUseCanvasBetaLayout, (state) => {
|
||||
state.doesCanvasNeedScaling = true;
|
||||
});
|
||||
|
||||
builder.addCase(setActiveTab, (state) => {
|
||||
state.doesCanvasNeedScaling = true;
|
||||
});
|
||||
builder.addCase(setAspectRatio, (state, action) => {
|
||||
const ratio = action.payload;
|
||||
if (ratio) {
|
||||
@ -907,8 +812,6 @@ export const {
|
||||
resetCanvas,
|
||||
resetCanvasInteractionState,
|
||||
resetCanvasView,
|
||||
resizeAndScaleCanvas,
|
||||
resizeCanvas,
|
||||
setBoundingBoxCoordinates,
|
||||
setBoundingBoxDimensions,
|
||||
setBoundingBoxPreviewFill,
|
||||
@ -916,10 +819,8 @@ export const {
|
||||
flipBoundingBoxAxes,
|
||||
setBrushColor,
|
||||
setBrushSize,
|
||||
setCanvasContainerDimensions,
|
||||
setColorPickerColor,
|
||||
setCursorPosition,
|
||||
setDoesCanvasNeedScaling,
|
||||
setInitialCanvasImage,
|
||||
setIsDrawing,
|
||||
setIsMaskEnabled,
|
||||
@ -958,6 +859,7 @@ export const {
|
||||
stagingAreaInitialized,
|
||||
canvasSessionIdChanged,
|
||||
setShouldAntialias,
|
||||
canvasResized,
|
||||
} = canvasSlice.actions;
|
||||
|
||||
export default canvasSlice.reducer;
|
||||
|
@ -126,12 +126,9 @@ export interface CanvasState {
|
||||
boundingBoxScaleMethod: BoundingBoxScale;
|
||||
brushColor: RgbaColor;
|
||||
brushSize: number;
|
||||
canvasContainerDimensions: Dimensions;
|
||||
colorPickerColor: RgbaColor;
|
||||
cursorPosition: Vector2d | null;
|
||||
doesCanvasNeedScaling: boolean;
|
||||
futureLayerStates: CanvasLayerState[];
|
||||
isCanvasInitialized: boolean;
|
||||
isDrawing: boolean;
|
||||
isMaskEnabled: boolean;
|
||||
isMouseOverBoundingBox: boolean;
|
||||
|
@ -1,16 +0,0 @@
|
||||
import { AppDispatch, AppGetState } from 'app/store/store';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { debounce } from 'lodash-es';
|
||||
import { setDoesCanvasNeedScaling } from '../canvasSlice';
|
||||
|
||||
const debouncedCanvasScale = debounce((dispatch: AppDispatch) => {
|
||||
dispatch(setDoesCanvasNeedScaling(true));
|
||||
}, 300);
|
||||
|
||||
export const requestCanvasRescale =
|
||||
() => (dispatch: AppDispatch, getState: AppGetState) => {
|
||||
const activeTabName = activeTabNameSelector(getState());
|
||||
if (activeTabName === 'unifiedCanvas') {
|
||||
debouncedCanvasScale(dispatch);
|
||||
}
|
||||
};
|
@ -5,7 +5,9 @@ const getScaledCursorPosition = (stage: Stage) => {
|
||||
|
||||
const stageTransform = stage.getAbsoluteTransform().copy();
|
||||
|
||||
if (!pointerPosition || !stageTransform) return;
|
||||
if (!pointerPosition || !stageTransform) {
|
||||
return;
|
||||
}
|
||||
|
||||
const scaledCursorPosition = stageTransform.invert().point(pointerPosition);
|
||||
|
||||
|
@ -80,19 +80,19 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
sx={{
|
||||
flexDir: 'column',
|
||||
gap: 3,
|
||||
p: 3,
|
||||
p: 2,
|
||||
borderRadius: 'base',
|
||||
position: 'relative',
|
||||
bg: 'base.200',
|
||||
bg: 'base.250',
|
||||
_dark: {
|
||||
bg: 'base.850',
|
||||
bg: 'base.750',
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Flex sx={{ gap: 2, alignItems: 'center' }}>
|
||||
<IAISwitch
|
||||
tooltip={'Toggle this ControlNet'}
|
||||
aria-label={'Toggle this ControlNet'}
|
||||
tooltip="Toggle this ControlNet"
|
||||
aria-label="Toggle this ControlNet"
|
||||
isChecked={isEnabled}
|
||||
onChange={handleToggleIsEnabled}
|
||||
/>
|
||||
@ -194,7 +194,7 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
aspectRatio: '1/1',
|
||||
}}
|
||||
>
|
||||
<ControlNetImagePreview controlNet={controlNet} height={28} />
|
||||
<ControlNetImagePreview controlNet={controlNet} isSmall />
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
@ -207,7 +207,7 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
|
||||
{isExpanded && (
|
||||
<>
|
||||
<ControlNetImagePreview controlNet={controlNet} height="392px" />
|
||||
<ControlNetImagePreview controlNet={controlNet} />
|
||||
<ParamControlNetShouldAutoConfig controlNet={controlNet} />
|
||||
<ControlNetProcessorComponent controlNet={controlNet} />
|
||||
</>
|
||||
|
@ -1,14 +1,14 @@
|
||||
import { Box, Flex, Spinner, SystemStyleObject } from '@chakra-ui/react';
|
||||
import { Box, Flex, Spinner } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
import {
|
||||
TypesafeDraggableData,
|
||||
TypesafeDroppableData,
|
||||
} from 'features/dnd/types';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import {
|
||||
TypesafeDraggableData,
|
||||
TypesafeDroppableData,
|
||||
} from 'features/dnd/types';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { FaUndo } from 'react-icons/fa';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
@ -21,7 +21,7 @@ import {
|
||||
|
||||
type Props = {
|
||||
controlNet: ControlNetConfig;
|
||||
height: SystemStyleObject['h'];
|
||||
isSmall?: boolean;
|
||||
};
|
||||
|
||||
const selector = createSelector(
|
||||
@ -36,15 +36,14 @@ const selector = createSelector(
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ControlNetImagePreview = (props: Props) => {
|
||||
const { height } = props;
|
||||
const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
|
||||
const {
|
||||
controlImage: controlImageName,
|
||||
processedControlImage: processedControlImageName,
|
||||
processorType,
|
||||
isEnabled,
|
||||
controlNetId,
|
||||
} = props.controlNet;
|
||||
} = controlNet;
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
@ -109,7 +108,7 @@ const ControlNetImagePreview = (props: Props) => {
|
||||
sx={{
|
||||
position: 'relative',
|
||||
w: 'full',
|
||||
h: height,
|
||||
h: isSmall ? 28 : 366, // magic no touch
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
pointerEvents: isEnabled ? 'auto' : 'none',
|
||||
|
@ -4,7 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { isControlNetEnabledToggled } from 'features/controlNet/store/controlNetSlice';
|
||||
import { useCallback } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
@ -36,4 +36,4 @@ const ParamControlNetFeatureToggle = () => {
|
||||
);
|
||||
};
|
||||
|
||||
export default ParamControlNetFeatureToggle;
|
||||
export default memo(ParamControlNetFeatureToggle);
|
||||
|
@ -23,7 +23,7 @@ const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
|
||||
return (
|
||||
<IAISlider
|
||||
isDisabled={!isEnabled}
|
||||
label={'Weight'}
|
||||
label="Weight"
|
||||
value={weight}
|
||||
onChange={handleWeightChanged}
|
||||
min={0}
|
||||
|
@ -8,6 +8,7 @@ import ParamDynamicPromptsCombinatorial from './ParamDynamicPromptsCombinatorial
|
||||
import ParamDynamicPromptsToggle from './ParamDynamicPromptsEnabled';
|
||||
import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts';
|
||||
import { useFeatureStatus } from '../../system/hooks/useFeatureStatus';
|
||||
import { memo } from 'react';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
@ -40,4 +41,4 @@ const ParamDynamicPromptsCollapse = () => {
|
||||
);
|
||||
};
|
||||
|
||||
export default ParamDynamicPromptsCollapse;
|
||||
export default memo(ParamDynamicPromptsCollapse);
|
||||
|
@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { useCallback } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { combinatorialToggled } from '../store/dynamicPromptsSlice';
|
||||
|
||||
const selector = createSelector(
|
||||
@ -34,4 +34,4 @@ const ParamDynamicPromptsCombinatorial = () => {
|
||||
);
|
||||
};
|
||||
|
||||
export default ParamDynamicPromptsCombinatorial;
|
||||
export default memo(ParamDynamicPromptsCombinatorial);
|
||||
|
@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { useCallback } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { isEnabledToggled } from '../store/dynamicPromptsSlice';
|
||||
|
||||
const selector = createSelector(
|
||||
@ -33,4 +33,4 @@ const ParamDynamicPromptsToggle = () => {
|
||||
);
|
||||
};
|
||||
|
||||
export default ParamDynamicPromptsToggle;
|
||||
export default memo(ParamDynamicPromptsToggle);
|
||||
|
@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { useCallback } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import {
|
||||
maxPromptsChanged,
|
||||
maxPromptsReset,
|
||||
@ -60,4 +60,4 @@ const ParamDynamicPromptsMaxPrompts = () => {
|
||||
);
|
||||
};
|
||||
|
||||
export default ParamDynamicPromptsMaxPrompts;
|
||||
export default memo(ParamDynamicPromptsMaxPrompts);
|
||||
|
@ -13,7 +13,7 @@ import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSe
|
||||
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { PropsWithChildren, useCallback, useMemo, useRef } from 'react';
|
||||
import { PropsWithChildren, memo, useCallback, useMemo, useRef } from 'react';
|
||||
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
|
||||
|
||||
@ -118,7 +118,7 @@ const ParamEmbeddingPopover = (props: Props) => {
|
||||
<IAIMantineSearchableSelect
|
||||
inputRef={inputRef}
|
||||
autoFocus
|
||||
placeholder={'Add Embedding'}
|
||||
placeholder="Add Embedding"
|
||||
value={null}
|
||||
data={data}
|
||||
nothingFound="No matching Embeddings"
|
||||
@ -140,4 +140,4 @@ const ParamEmbeddingPopover = (props: Props) => {
|
||||
);
|
||||
};
|
||||
|
||||
export default ParamEmbeddingPopover;
|
||||
export default memo(ParamEmbeddingPopover);
|
||||
|
@ -1,4 +1,5 @@
|
||||
import { Badge, Flex } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
|
||||
const AutoAddIcon = () => {
|
||||
return (
|
||||
@ -20,4 +21,4 @@ const AutoAddIcon = () => {
|
||||
);
|
||||
};
|
||||
|
||||
export default AutoAddIcon;
|
||||
export default memo(AutoAddIcon);
|
||||
|
@ -6,7 +6,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||
import { autoAddBoardIdChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { useCallback, useRef } from 'react';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
|
||||
|
||||
const selector = createSelector(
|
||||
@ -66,7 +66,7 @@ const BoardAutoAddSelect = () => {
|
||||
label="Auto-Add Board"
|
||||
inputRef={inputRef}
|
||||
autoFocus
|
||||
placeholder={'Select a Board'}
|
||||
placeholder="Select a Board"
|
||||
value={autoAddBoardId}
|
||||
data={boards}
|
||||
nothingFound="No matching Boards"
|
||||
@ -81,4 +81,4 @@ const BoardAutoAddSelect = () => {
|
||||
);
|
||||
};
|
||||
|
||||
export default BoardAutoAddSelect;
|
||||
export default memo(BoardAutoAddSelect);
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user