Merge branch 'main' into feat_compel_and

This commit is contained in:
Millun Atluri 2023-08-24 17:38:35 +10:00 committed by GitHub
commit 65feb92286
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
360 changed files with 10268 additions and 9620 deletions

37
.gitignore vendored
View File

@ -1,23 +1,8 @@
# ignore default image save location and model symbolic link
.idea/
embeddings/
outputs/
models/ldm/stable-diffusion-v1/model.ckpt
**/restoration/codeformer/weights
# ignore user models config
configs/models.user.yaml
config/models.user.yml
invokeai.init
.version
.last_model
# ignore the Anaconda/Miniconda installer used while building Docker image
anaconda.sh
# ignore a directory which serves as a place for initial images
inputs/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
@ -189,39 +174,17 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
src
**/__pycache__/
outputs
# Logs and associated folders
# created from generated embeddings.
logs
testtube
checkpoints
# If it's a Mac
.DS_Store
invokeai/frontend/yarn.lock
invokeai/frontend/node_modules
# Let the frontend manage its own gitignore
!invokeai/frontend/web/*
# Scratch folder
.scratch/
.vscode/
gfpgan/
models/ldm/stable-diffusion-v1/*.sha256
# GFPGAN model files
gfpgan/
# config file (will be created by installer)
configs/models.yaml
# ignore initfile
.invokeai
# ignore environment.yml and requirements.txt
# these are links to the real files in environments-and-requirements

View File

@ -175,22 +175,27 @@ These configuration settings allow you to enable and disable various InvokeAI fe
| `internet_available` | `true` | When a resource is not available locally, try to fetch it via the internet |
| `log_tokenization` | `false` | Before each text2image generation, print a color-coded representation of the prompt to the console; this can help understand why a prompt is not working as expected |
| `patchmatch` | `true` | Activate the "patchmatch" algorithm for improved inpainting |
| `restore` | `true` | Activate the facial restoration features (DEPRECATED; restoration features will be removed in 3.0.0) |
### Memory/Performance
### Generation
These options tune InvokeAI's memory and performance characteristics.
| Setting | Default Value | Description |
|----------|----------------|--------------|
| `always_use_cpu` | `false` | Use the CPU to generate images, even if a GPU is available |
| `free_gpu_mem` | `false` | Aggressively free up GPU memory after each operation; this will allow you to run in low-VRAM environments with some performance penalties |
| `max_cache_size` | `6` | Amount of CPU RAM (in GB) to reserve for caching models in memory; more cache allows you to keep models in memory and switch among them quickly |
| `max_vram_cache_size` | `2.75` | Amount of GPU VRAM (in GB) to reserve for caching models in VRAM; more cache speeds up generation but reduces the size of the images that can be generated. This can be set to zero to maximize the amount of memory available for generation. |
| `precision` | `auto` | Floating point precision. One of `auto`, `float16` or `float32`. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system |
| `sequential_guidance` | `false` | Calculate guidance in serial rather than in parallel, lowering memory requirements at the cost of some performance loss |
| `xformers_enabled` | `true` | If the x-formers memory-efficient attention module is installed, activate it for better memory usage and generation speed|
| `tiled_decode` | `false` | If true, then during the VAE decoding phase the image will be decoded a section at a time, reducing memory consumption at the cost of a performance hit |
| Setting | Default Value | Description |
|-----------------------|---------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `sequential_guidance` | `false` | Calculate guidance in serial rather than in parallel, lowering memory requirements at the cost of some performance loss |
| `attention_type` | `auto` | Select the type of attention to use. One of `auto`,`normal`,`xformers`,`sliced`, or `torch-sdp` |
| `attention_slice_size` | `auto` | When "sliced" attention is selected, set the slice size. One of `auto`, `balanced`, `max` or the integers 1-8|
| `force_tiled_decode` | `false` | Force the VAE step to decode in tiles, reducing memory consumption at the cost of performance |
### Device
These options configure the generation execution device.
| Setting | Default Value | Description |
|-----------------------|---------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `device` | `auto` | Preferred execution device. One of `auto`, `cpu`, `cuda`, `cuda:1`, `mps`. `auto` will choose the device depending on the hardware platform and the installed torch capabilities. |
| `precision` | `auto` | Floating point precision. One of `auto`, `float16` or `float32`. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system |
### Paths

View File

@ -55,7 +55,7 @@ async def get_version() -> AppVersion:
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
async def get_config() -> AppConfig:
infill_methods = ["tile"]
infill_methods = ["tile", "lama"]
if PatchMatch.patchmatch_available():
infill_methods.append("patchmatch")

View File

@ -122,6 +122,7 @@ def custom_openapi():
output_schemas = schema(output_types, ref_prefix="#/components/schemas/")
for schema_key, output_schema in output_schemas["definitions"].items():
output_schema["class"] = "output"
openapi_schema["components"]["schemas"][schema_key] = output_schema
# TODO: note that we assume the schema_key here is the TYPE.__name__
@ -130,8 +131,8 @@ def custom_openapi():
# Add Node Editor UI helper schemas
ui_config_schemas = schema([UIConfigBase, _InputField, _OutputField], ref_prefix="#/components/schemas/")
for schema_key, output_schema in ui_config_schemas["definitions"].items():
openapi_schema["components"]["schemas"][schema_key] = output_schema
for schema_key, ui_config_schema in ui_config_schemas["definitions"].items():
openapi_schema["components"]["schemas"][schema_key] = ui_config_schema
# Add a reference to the output type to additionalProperties of the invoker schema
for invoker in all_invocations:
@ -140,8 +141,8 @@ def custom_openapi():
output_type_title = output_type_titles[output_type.__name__]
invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
invoker_schema["output"] = outputs_ref
invoker_schema["class"] = "invocation"
from invokeai.backend.model_management.models import get_model_config_enums

View File

@ -71,6 +71,9 @@ class FieldDescriptions:
safe_mode = "Whether or not to use safe mode"
scribble_mode = "Whether or not to use scribble mode"
scale_factor = "The factor by which to scale"
blend_alpha = (
"Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B."
)
num_1 = "The first number"
num_2 = "The second number"
mask = "The mask to use for the operation"
@ -140,6 +143,7 @@ class UIType(str, Enum):
# region Misc
FilePath = "FilePath"
Enum = "enum"
Scheduler = "Scheduler"
# endregion
@ -166,6 +170,7 @@ class _InputField(BaseModel):
ui_hidden: bool
ui_type: Optional[UIType]
ui_component: Optional[UIComponent]
ui_order: Optional[int]
class _OutputField(BaseModel):
@ -178,6 +183,7 @@ class _OutputField(BaseModel):
ui_hidden: bool
ui_type: Optional[UIType]
ui_order: Optional[int]
def InputField(
@ -211,6 +217,7 @@ def InputField(
ui_type: Optional[UIType] = None,
ui_component: Optional[UIComponent] = None,
ui_hidden: bool = False,
ui_order: Optional[int] = None,
**kwargs: Any,
) -> Any:
"""
@ -269,6 +276,7 @@ def InputField(
ui_type=ui_type,
ui_component=ui_component,
ui_hidden=ui_hidden,
ui_order=ui_order,
**kwargs,
)
@ -302,6 +310,7 @@ def OutputField(
repr: bool = True,
ui_type: Optional[UIType] = None,
ui_hidden: bool = False,
ui_order: Optional[int] = None,
**kwargs: Any,
) -> Any:
"""
@ -348,6 +357,7 @@ def OutputField(
repr=repr,
ui_type=ui_type,
ui_hidden=ui_hidden,
ui_order=ui_order,
**kwargs,
)
@ -376,7 +386,7 @@ class BaseInvocationOutput(BaseModel):
"""Base class for all invocation outputs"""
# All outputs must include a type name like this:
# type: Literal['your_output_name']
# type: Literal['your_output_name'] # noqa f821
@classmethod
def get_all_subclasses_tuple(cls):
@ -389,6 +399,13 @@ class BaseInvocationOutput(BaseModel):
toprocess.extend(next_subclasses)
return tuple(subclasses)
class Config:
@staticmethod
def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = list()
schema["required"].extend(["type"])
class RequiredConnectionException(Exception):
"""Raised when an field which requires a connection did not receive a value."""
@ -410,7 +427,7 @@ class BaseInvocation(ABC, BaseModel):
"""
# All invocations must include a type name like this:
# type: Literal['your_output_name']
# type: Literal['your_output_name'] # noqa f821
@classmethod
def get_all_subclasses(cls):
@ -449,6 +466,9 @@ class BaseInvocation(ABC, BaseModel):
schema["title"] = uiconfig.title
if uiconfig and hasattr(uiconfig, "tags"):
schema["tags"] = uiconfig.tags
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = list()
schema["required"].extend(["type", "id"])
@abstractmethod
def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
@ -485,7 +505,7 @@ class BaseInvocation(ABC, BaseModel):
raise MissingInputException(self.__fields__["type"].default, field_name)
return self.invoke(context)
id: str = InputField(description="The id of this node. Must be unique among all nodes.")
id: str = Field(description="The id of this node. Must be unique among all nodes.")
is_intermediate: bool = InputField(
default=False, description="Whether or not this node is an intermediate node.", input=Input.Direct
)

View File

@ -232,7 +232,7 @@ class SDXLPromptInvocationBase:
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=False, # TODO:
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
requires_pooled=True,
requires_pooled=get_pooled,
)
conjunction = Compel.parse_prompt_string(prompt)

View File

@ -8,7 +8,7 @@ import numpy
from PIL import Image, ImageChops, ImageFilter, ImageOps
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker
@ -41,6 +41,39 @@ class ShowImageInvocation(BaseInvocation):
)
@title("Blank Image")
@tags("image")
class BlankImageInvocation(BaseInvocation):
"""Creates a blank image and forwards it to the pipeline"""
# Metadata
type: Literal["blank_image"] = "blank_image"
# Inputs
width: int = InputField(default=512, description="The width of the image")
height: int = InputField(default=512, description="The height of the image")
mode: Literal["RGB", "RGBA"] = InputField(default="RGB", description="The mode of the image")
color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color of the image")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = Image.new(mode=self.mode, size=(self.width, self.height), color=self.color.tuple())
image_dto = context.services.images.create(
image=image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@title("Crop Image")
@tags("image", "crop")
class ImageCropInvocation(BaseInvocation):

View File

@ -1,23 +1,25 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
import math
from typing import Literal, Optional, get_args
import numpy as np
import math
from PIL import Image, ImageOps
from invokeai.app.invocations.primitives import ImageField, ImageOutput, ColorField
from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.image_util.lama import LaMA
from invokeai.backend.image_util.patchmatch import PatchMatch
from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, InputField, InvocationContext, title, tags
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
def infill_methods() -> list[str]:
methods = [
"tile",
"solid",
"lama",
]
if PatchMatch.patchmatch_available():
methods.insert(0, "patchmatch")
@ -28,6 +30,11 @@ INFILL_METHODS = Literal[tuple(infill_methods())]
DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
def infill_lama(im: Image.Image) -> Image.Image:
lama = LaMA()
return lama(im)
def infill_patchmatch(im: Image.Image) -> Image.Image:
if im.mode != "RGBA":
return im
@ -90,7 +97,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
return im
# Find all invalid tiles and replace with a random valid tile
replace_count = (tiles_mask is False).sum()
replace_count = (tiles_mask == False).sum() # noqa: E712
rng = np.random.default_rng(seed=seed)
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]
@ -218,3 +225,34 @@ class InfillPatchMatchInvocation(BaseInvocation):
width=image_dto.width,
height=image_dto.height,
)
@title("LaMa Infill")
@tags("image", "inpaint")
class LaMaInfillInvocation(BaseInvocation):
"""Infills transparent areas of an image using the LaMa model"""
type: Literal["infill_lama"] = "infill_lama"
# Inputs
image: ImageField = InputField(description="The image to infill")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
infilled = infill_lama(image.copy())
image_dto = context.services.images.create(
image=infilled,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -4,6 +4,7 @@ from contextlib import ExitStack
from typing import List, Literal, Optional, Union
import einops
import numpy as np
import torch
import torchvision.transforms as T
from diffusers.image_processor import VaeImageProcessor
@ -106,24 +107,28 @@ class DenoiseLatentsInvocation(BaseInvocation):
# Inputs
positive_conditioning: ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
)
negative_conditioning: ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
)
noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection)
noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection, ui_order=3)
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
cfg_scale: Union[float, List[float]] = InputField(
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, ui_type=UIType.Float
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, ui_type=UIType.Float, title="CFG Scale"
)
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
scheduler: SAMPLER_NAME_VALUES = InputField(default="euler", description=FieldDescriptions.scheduler)
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection)
control: Union[ControlField, list[ControlField]] = InputField(
default=None, description=FieldDescriptions.control, input=Input.Connection
scheduler: SAMPLER_NAME_VALUES = InputField(
default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler
)
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet", ui_order=2)
control: Union[ControlField, list[ControlField]] = InputField(
default=None, description=FieldDescriptions.control, input=Input.Connection, ui_order=5
)
latents: Optional[LatentsField] = InputField(
description=FieldDescriptions.latents, input=Input.Connection, ui_order=4
)
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
mask: Optional[ImageField] = InputField(
default=None,
description=FieldDescriptions.mask,
@ -453,7 +458,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
@title("Latents to Image")
@tags("latents", "image", "vae")
@tags("latents", "image", "vae", "l2i")
class LatentsToImageInvocation(BaseInvocation):
"""Generates an image from latents."""
@ -641,7 +646,7 @@ class ScaleLatentsInvocation(BaseInvocation):
@title("Image to Latents")
@tags("latents", "image", "vae")
@tags("latents", "image", "vae", "i2l")
class ImageToLatentsInvocation(BaseInvocation):
"""Encodes an image into latents."""
@ -720,3 +725,81 @@ class ImageToLatentsInvocation(BaseInvocation):
latents = latents.to("cpu")
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents, seed=None)
@title("Blend Latents")
@tags("latents", "blend")
class BlendLatentsInvocation(BaseInvocation):
"""Blend two latents using a given alpha. Latents must have same size."""
type: Literal["lblend"] = "lblend"
# Inputs
latents_a: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
latents_b: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents_a = context.services.latents.get(self.latents_a.latents_name)
latents_b = context.services.latents.get(self.latents_b.latents_name)
if latents_a.shape != latents_b.shape:
raise "Latents to blend must be the same size."
# TODO:
device = choose_torch_device()
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
"""
Spherical linear interpolation
Args:
t (float/np.ndarray): Float value between 0.0 and 1.0
v0 (np.ndarray): Starting vector
v1 (np.ndarray): Final vector
DOT_THRESHOLD (float): Threshold for considering the two vectors as
colineal. Not recommended to alter this.
Returns:
v2 (np.ndarray): Interpolation vector between v0 and v1
"""
inputs_are_torch = False
if not isinstance(v0, np.ndarray):
inputs_are_torch = True
v0 = v0.detach().cpu().numpy()
if not isinstance(v1, np.ndarray):
inputs_are_torch = True
v1 = v1.detach().cpu().numpy()
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
if np.abs(dot) > DOT_THRESHOLD:
v2 = (1 - t) * v0 + t * v1
else:
theta_0 = np.arccos(dot)
sin_theta_0 = np.sin(theta_0)
theta_t = theta_0 * t
sin_theta_t = np.sin(theta_t)
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
s1 = sin_theta_t / sin_theta_0
v2 = s0 * v0 + s1 * v1
if inputs_are_torch:
v2 = torch.from_numpy(v2).to(device)
return v2
# blend
blended_latents = slerp(self.alpha, latents_a, latents_b)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
blended_latents = blended_latents.to("cpu")
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
# context.services.latents.set(name, resized_latents)
context.services.latents.save(name, blended_latents)
return build_latents_output(latents_name=name, latents=blended_latents)

View File

@ -21,7 +21,7 @@ class AddInvocation(BaseInvocation):
b: int = InputField(default=0, description=FieldDescriptions.num_2)
def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntegerOutput(a=self.a + self.b)
return IntegerOutput(value=self.a + self.b)
@title("Subtract Integers")
@ -36,7 +36,7 @@ class SubtractInvocation(BaseInvocation):
b: int = InputField(default=0, description=FieldDescriptions.num_2)
def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntegerOutput(a=self.a - self.b)
return IntegerOutput(value=self.a - self.b)
@title("Multiply Integers")
@ -51,7 +51,7 @@ class MultiplyInvocation(BaseInvocation):
b: int = InputField(default=0, description=FieldDescriptions.num_2)
def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntegerOutput(a=self.a * self.b)
return IntegerOutput(value=self.a * self.b)
@title("Divide Integers")
@ -66,7 +66,7 @@ class DivideInvocation(BaseInvocation):
b: int = InputField(default=0, description=FieldDescriptions.num_2)
def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntegerOutput(a=int(self.a / self.b))
return IntegerOutput(value=int(self.a / self.b))
@title("Random Integer")
@ -81,4 +81,4 @@ class RandomIntInvocation(BaseInvocation):
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntegerOutput(a=np.random.randint(self.low, self.high))
return IntegerOutput(value=np.random.randint(self.low, self.high))

View File

@ -72,7 +72,7 @@ class LoRAModelField(BaseModel):
base_model: BaseModelType = Field(description="Base model")
@title("Main Model Loader")
@title("Main Model")
@tags("model")
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
@ -179,7 +179,7 @@ class LoraLoaderOutput(BaseInvocationOutput):
# fmt: on
@title("LoRA Loader")
@title("LoRA")
@tags("lora", "model")
class LoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
@ -257,7 +257,7 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput):
# fmt: on
@title("SDXL LoRA Loader")
@title("SDXL LoRA")
@tags("sdxl", "lora", "model")
class SDXLLoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
@ -356,7 +356,7 @@ class VaeLoaderOutput(BaseInvocationOutput):
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@title("VAE Loader")
@title("VAE")
@tags("vae", "model")
class VaeLoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""

View File

@ -169,7 +169,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
ui_type=UIType.Float,
)
scheduler: SAMPLER_NAME_VALUES = InputField(
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct, ui_type=UIType.Scheduler
)
precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision)
unet: UNetField = InputField(
@ -406,7 +406,7 @@ class OnnxModelField(BaseModel):
model_type: ModelType = Field(description="Model Type")
@title("ONNX Model Loader")
@title("ONNX Main Model")
@tags("onnx", "model")
class OnnxModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""

View File

@ -2,8 +2,8 @@
from typing import Literal, Optional, Tuple
from pydantic import BaseModel, Field
import torch
from pydantic import BaseModel, Field
from .baseinvocation import (
BaseInvocation,
@ -33,7 +33,7 @@ class BooleanOutput(BaseInvocationOutput):
"""Base class for nodes that output a single boolean"""
type: Literal["boolean_output"] = "boolean_output"
a: bool = OutputField(description="The output boolean")
value: bool = OutputField(description="The output boolean")
class BooleanCollectionOutput(BaseInvocationOutput):
@ -42,9 +42,7 @@ class BooleanCollectionOutput(BaseInvocationOutput):
type: Literal["boolean_collection_output"] = "boolean_collection_output"
# Outputs
collection: list[bool] = OutputField(
default_factory=list, description="The output boolean collection", ui_type=UIType.BooleanCollection
)
collection: list[bool] = OutputField(description="The output boolean collection", ui_type=UIType.BooleanCollection)
@title("Boolean Primitive")
@ -55,10 +53,10 @@ class BooleanInvocation(BaseInvocation):
type: Literal["boolean"] = "boolean"
# Inputs
a: bool = InputField(default=False, description="The boolean value")
value: bool = InputField(default=False, description="The boolean value")
def invoke(self, context: InvocationContext) -> BooleanOutput:
return BooleanOutput(a=self.a)
return BooleanOutput(value=self.value)
@title("Boolean Primitive Collection")
@ -70,7 +68,7 @@ class BooleanCollectionInvocation(BaseInvocation):
# Inputs
collection: list[bool] = InputField(
default=False, description="The collection of boolean values", ui_type=UIType.BooleanCollection
default_factory=list, description="The collection of boolean values", ui_type=UIType.BooleanCollection
)
def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
@ -86,7 +84,7 @@ class IntegerOutput(BaseInvocationOutput):
"""Base class for nodes that output a single integer"""
type: Literal["integer_output"] = "integer_output"
a: int = OutputField(description="The output integer")
value: int = OutputField(description="The output integer")
class IntegerCollectionOutput(BaseInvocationOutput):
@ -95,9 +93,7 @@ class IntegerCollectionOutput(BaseInvocationOutput):
type: Literal["integer_collection_output"] = "integer_collection_output"
# Outputs
collection: list[int] = OutputField(
default_factory=list, description="The int collection", ui_type=UIType.IntegerCollection
)
collection: list[int] = OutputField(description="The int collection", ui_type=UIType.IntegerCollection)
@title("Integer Primitive")
@ -108,10 +104,10 @@ class IntegerInvocation(BaseInvocation):
type: Literal["integer"] = "integer"
# Inputs
a: int = InputField(default=0, description="The integer value")
value: int = InputField(default=0, description="The integer value")
def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntegerOutput(a=self.a)
return IntegerOutput(value=self.value)
@title("Integer Primitive Collection")
@ -139,7 +135,7 @@ class FloatOutput(BaseInvocationOutput):
"""Base class for nodes that output a single float"""
type: Literal["float_output"] = "float_output"
a: float = OutputField(description="The output float")
value: float = OutputField(description="The output float")
class FloatCollectionOutput(BaseInvocationOutput):
@ -148,9 +144,7 @@ class FloatCollectionOutput(BaseInvocationOutput):
type: Literal["float_collection_output"] = "float_collection_output"
# Outputs
collection: list[float] = OutputField(
default_factory=list, description="The float collection", ui_type=UIType.FloatCollection
)
collection: list[float] = OutputField(description="The float collection", ui_type=UIType.FloatCollection)
@title("Float Primitive")
@ -161,10 +155,10 @@ class FloatInvocation(BaseInvocation):
type: Literal["float"] = "float"
# Inputs
param: float = InputField(default=0.0, description="The float value")
value: float = InputField(default=0.0, description="The float value")
def invoke(self, context: InvocationContext) -> FloatOutput:
return FloatOutput(a=self.param)
return FloatOutput(value=self.value)
@title("Float Primitive Collection")
@ -176,7 +170,7 @@ class FloatCollectionInvocation(BaseInvocation):
# Inputs
collection: list[float] = InputField(
default=0, description="The collection of float values", ui_type=UIType.FloatCollection
default_factory=list, description="The collection of float values", ui_type=UIType.FloatCollection
)
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
@ -192,7 +186,7 @@ class StringOutput(BaseInvocationOutput):
"""Base class for nodes that output a single string"""
type: Literal["string_output"] = "string_output"
text: str = OutputField(description="The output string")
value: str = OutputField(description="The output string")
class StringCollectionOutput(BaseInvocationOutput):
@ -201,9 +195,7 @@ class StringCollectionOutput(BaseInvocationOutput):
type: Literal["string_collection_output"] = "string_collection_output"
# Outputs
collection: list[str] = OutputField(
default_factory=list, description="The output strings", ui_type=UIType.StringCollection
)
collection: list[str] = OutputField(description="The output strings", ui_type=UIType.StringCollection)
@title("String Primitive")
@ -214,10 +206,10 @@ class StringInvocation(BaseInvocation):
type: Literal["string"] = "string"
# Inputs
text: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea)
value: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea)
def invoke(self, context: InvocationContext) -> StringOutput:
return StringOutput(text=self.text)
return StringOutput(value=self.value)
@title("String Primitive Collection")
@ -229,7 +221,7 @@ class StringCollectionInvocation(BaseInvocation):
# Inputs
collection: list[str] = InputField(
default=0, description="The collection of string values", ui_type=UIType.StringCollection
default_factory=list, description="The collection of string values", ui_type=UIType.StringCollection
)
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
@ -262,9 +254,7 @@ class ImageCollectionOutput(BaseInvocationOutput):
type: Literal["image_collection_output"] = "image_collection_output"
# Outputs
collection: list[ImageField] = OutputField(
default_factory=list, description="The output images", ui_type=UIType.ImageCollection
)
collection: list[ImageField] = OutputField(description="The output images", ui_type=UIType.ImageCollection)
@title("Image Primitive")
@ -334,7 +324,6 @@ class LatentsCollectionOutput(BaseInvocationOutput):
type: Literal["latents_collection_output"] = "latents_collection_output"
collection: list[LatentsField] = OutputField(
default_factory=list,
description=FieldDescriptions.latents,
ui_type=UIType.LatentsCollection,
)
@ -365,7 +354,7 @@ class LatentsCollectionInvocation(BaseInvocation):
# Inputs
collection: list[LatentsField] = InputField(
default=0, description="The collection of latents tensors", ui_type=UIType.LatentsCollection
description="The collection of latents tensors", ui_type=UIType.LatentsCollection
)
def invoke(self, context: InvocationContext) -> LatentsCollectionOutput:
@ -410,9 +399,7 @@ class ColorCollectionOutput(BaseInvocationOutput):
type: Literal["color_collection_output"] = "color_collection_output"
# Outputs
collection: list[ColorField] = OutputField(
default_factory=list, description="The output colors", ui_type=UIType.ColorCollection
)
collection: list[ColorField] = OutputField(description="The output colors", ui_type=UIType.ColorCollection)
@title("Color Primitive")
@ -455,7 +442,6 @@ class ConditioningCollectionOutput(BaseInvocationOutput):
# Outputs
collection: list[ConditioningField] = OutputField(
default_factory=list,
description="The output conditioning tensors",
ui_type=UIType.ConditioningCollection,
)

View File

@ -37,7 +37,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@title("SDXL Main Model Loader")
@title("SDXL Main Model")
@tags("model", "sdxl")
class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels."""
@ -122,7 +122,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
)
@title("SDXL Refiner Model Loader")
@title("SDXL Refiner Model")
@tags("model", "sdxl", "refiner")
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels."""

View File

@ -0,0 +1,8 @@
"""
Init file for InvokeAI configure package
"""
from .invokeai_config import ( # noqa F401
InvokeAIAppConfig,
get_invokeai_config,
)

View File

@ -0,0 +1,239 @@
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team
"""
Base class for the InvokeAI configuration system.
It defines a type of pydantic BaseSettings object that
is able to read and write from an omegaconf-based config file,
with overriding of settings from environment variables and/or
the command line.
"""
from __future__ import annotations
import argparse
import os
import pydoc
import sys
from argparse import ArgumentParser
from omegaconf import OmegaConf, DictConfig, ListConfig
from pathlib import Path
from pydantic import BaseSettings
from typing import ClassVar, Dict, List, Literal, Union, get_origin, get_type_hints, get_args
class PagingArgumentParser(argparse.ArgumentParser):
"""
A custom ArgumentParser that uses pydoc to page its output.
It also supports reading defaults from an init file.
"""
def print_help(self, file=None):
text = self.format_help()
pydoc.pager(text)
class InvokeAISettings(BaseSettings):
"""
Runtime configuration settings in which default values are
read from an omegaconf .yaml file.
"""
initconf: ClassVar[DictConfig] = None
argparse_groups: ClassVar[Dict] = {}
def parse_args(self, argv: list = sys.argv[1:]):
parser = self.get_parser()
opt = parser.parse_args(argv)
for name in self.__fields__:
if name not in self._excluded():
value = getattr(opt, name)
if isinstance(value, ListConfig):
value = list(value)
elif isinstance(value, DictConfig):
value = dict(value)
setattr(self, name, value)
def to_yaml(self) -> str:
"""
Return a YAML string representing our settings. This can be used
as the contents of `invokeai.yaml` to restore settings later.
"""
cls = self.__class__
type = get_args(get_type_hints(cls)["type"])[0]
field_dict = dict({type: dict()})
for name, field in self.__fields__.items():
if name in cls._excluded_from_yaml():
continue
category = field.field_info.extra.get("category") or "Uncategorized"
value = getattr(self, name)
if category not in field_dict[type]:
field_dict[type][category] = dict()
# keep paths as strings to make it easier to read
field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
conf = OmegaConf.create(field_dict)
return OmegaConf.to_yaml(conf)
@classmethod
def add_parser_arguments(cls, parser):
if "type" in get_type_hints(cls):
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
else:
settings_stanza = "Uncategorized"
env_prefix = cls.Config.env_prefix if hasattr(cls.Config, "env_prefix") else settings_stanza.upper()
initconf = (
cls.initconf.get(settings_stanza)
if cls.initconf and settings_stanza in cls.initconf
else OmegaConf.create()
)
# create an upcase version of the environment in
# order to achieve case-insensitive environment
# variables (the way Windows does)
upcase_environ = dict()
for key, value in os.environ.items():
upcase_environ[key.upper()] = value
fields = cls.__fields__
cls.argparse_groups = {}
for name, field in fields.items():
if name not in cls._excluded():
current_default = field.default
category = field.field_info.extra.get("category", "Uncategorized")
env_name = env_prefix + "_" + name
if category in initconf and name in initconf.get(category):
field.default = initconf.get(category).get(name)
if env_name.upper() in upcase_environ:
field.default = upcase_environ[env_name.upper()]
cls.add_field_argument(parser, name, field)
field.default = current_default
@classmethod
def cmd_name(self, command_field: str = "type") -> str:
hints = get_type_hints(self)
if command_field in hints:
return get_args(hints[command_field])[0]
else:
return "Uncategorized"
@classmethod
def get_parser(cls) -> ArgumentParser:
parser = PagingArgumentParser(
prog=cls.cmd_name(),
description=cls.__doc__,
)
cls.add_parser_arguments(parser)
return parser
@classmethod
def add_subparser(cls, parser: argparse.ArgumentParser):
parser.add_parser(cls.cmd_name(), help=cls.__doc__)
@classmethod
def _excluded(self) -> List[str]:
# internal fields that shouldn't be exposed as command line options
return ["type", "initconf"]
@classmethod
def _excluded_from_yaml(self) -> List[str]:
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
return [
"type",
"initconf",
"version",
"from_file",
"model",
"root",
"max_cache_size",
"max_vram_cache_size",
"always_use_cpu",
"free_gpu_mem",
"xformers_enabled",
"tiled_decode",
]
class Config:
env_file_encoding = "utf-8"
arbitrary_types_allowed = True
case_sensitive = True
@classmethod
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
field_type = get_type_hints(cls).get(name)
default = (
default_override
if default_override is not None
else field.default
if field.default_factory is None
else field.default_factory()
)
if category := field.field_info.extra.get("category"):
if category not in cls.argparse_groups:
cls.argparse_groups[category] = command_parser.add_argument_group(category)
argparse_group = cls.argparse_groups[category]
else:
argparse_group = command_parser
if get_origin(field_type) == Literal:
allowed_values = get_args(field.type_)
allowed_types = set()
for val in allowed_values:
allowed_types.add(type(val))
allowed_types_list = list(allowed_types)
field_type = allowed_types_list[0] if len(allowed_types) == 1 else int_or_float_or_str
argparse_group.add_argument(
f"--{name}",
dest=name,
type=field_type,
default=default,
choices=allowed_values,
help=field.field_info.description,
)
elif get_origin(field_type) == Union:
argparse_group.add_argument(
f"--{name}",
dest=name,
type=int_or_float_or_str,
default=default,
help=field.field_info.description,
)
elif get_origin(field_type) == list:
argparse_group.add_argument(
f"--{name}",
dest=name,
nargs="*",
type=field.type_,
default=default,
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
help=field.field_info.description,
)
else:
argparse_group.add_argument(
f"--{name}",
dest=name,
type=field.type_,
default=default,
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
help=field.field_info.description,
)
def int_or_float_or_str(value: str) -> Union[int, float, str]:
"""
Workaround for argparse type checking.
"""
try:
return int(value)
except Exception as e: # noqa F841
pass
try:
return float(value)
except Exception as e: # noqa F841
pass
return str(value)

View File

@ -10,37 +10,49 @@ categories returned by `invokeai --help`. The file looks like this:
[file: invokeai.yaml]
InvokeAI:
Paths:
root: /home/lstein/invokeai-main
conf_path: configs/models.yaml
legacy_conf_dir: configs/stable-diffusion
outdir: outputs
autoimport_dir: null
Models:
model: stable-diffusion-1.5
embeddings: true
Memory/Performance:
xformers_enabled: false
sequential_guidance: false
precision: float16
max_cache_size: 6
max_vram_cache_size: 0.5
always_use_cpu: false
free_gpu_mem: false
Features:
esrgan: true
patchmatch: true
internet_available: true
log_tokenization: false
Web Server:
host: 127.0.0.1
port: 8081
port: 9090
allow_origins: []
allow_credentials: true
allow_methods:
- '*'
allow_headers:
- '*'
Features:
esrgan: true
internet_available: true
log_tokenization: false
patchmatch: true
ignore_missing_core_models: false
Paths:
autoimport_dir: autoimport
lora_dir: null
embedding_dir: null
controlnet_dir: null
conf_path: configs/models.yaml
models_dir: models
legacy_conf_dir: configs/stable-diffusion
db_dir: databases
outdir: /home/lstein/invokeai-main/outputs
use_memory_db: false
Logging:
log_handlers:
- console
log_format: plain
log_level: info
Model Cache:
ram: 13.5
vram: 0.25
lazy_offload: true
Device:
device: auto
precision: auto
Generation:
sequential_guidance: false
attention_type: xformers
attention_slice_size: auto
force_tiled_decode: false
The default name of the configuration file is `invokeai.yaml`, located
in INVOKEAI_ROOT. You can replace supersede this by providing any
@ -54,24 +66,23 @@ InvokeAIAppConfig.parse_args() will parse the contents of `sys.argv`
at initialization time. You may pass a list of strings in the optional
`argv` argument to use instead of the system argv:
conf.parse_args(argv=['--xformers_enabled'])
conf.parse_args(argv=['--log_tokenization'])
It is also possible to set a value at initialization time. However, if
you call parse_args() it may be overwritten.
conf = InvokeAIAppConfig(xformers_enabled=True)
conf.parse_args(argv=['--no-xformers'])
conf.xformers_enabled
conf = InvokeAIAppConfig(log_tokenization=True)
conf.parse_args(argv=['--no-log_tokenization'])
conf.log_tokenization
# False
To avoid this, use `get_config()` to retrieve the application-wide
configuration object. This will retain any properties set at object
creation time:
conf = InvokeAIAppConfig.get_config(xformers_enabled=True)
conf.parse_args(argv=['--no-xformers'])
conf.xformers_enabled
conf = InvokeAIAppConfig.get_config(log_tokenization=True)
conf.parse_args(argv=['--no-log_tokenization'])
conf.log_tokenization
# True
Any setting can be overwritten by setting an environment variable of
@ -93,7 +104,7 @@ Typical usage at the top level file:
# get global configuration and print its cache size
conf = InvokeAIAppConfig.get_config()
conf.parse_args()
print(conf.max_cache_size)
print(conf.ram_cache_size)
Typical usage in a backend module:
@ -101,8 +112,7 @@ Typical usage in a backend module:
# get global configuration and print its cache size value
conf = InvokeAIAppConfig.get_config()
print(conf.max_cache_size)
print(conf.ram_cache_size)
Computed properties:
@ -159,15 +169,13 @@ two configs are kept in separate sections of the config file:
"""
from __future__ import annotations
import argparse
import pydoc
import os
import sys
from argparse import ArgumentParser
from omegaconf import OmegaConf, DictConfig, ListConfig
from omegaconf import OmegaConf, DictConfig
from pathlib import Path
from pydantic import BaseSettings, Field, parse_obj_as
from typing import ClassVar, Dict, List, Literal, Union, get_origin, get_type_hints, get_args
from pydantic import Field, parse_obj_as
from typing import ClassVar, Dict, List, Literal, Union, Optional, get_type_hints
from .base import InvokeAISettings
INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db")
@ -175,195 +183,6 @@ LEGACY_INIT_FILE = Path("invokeai.init")
DEFAULT_MAX_VRAM = 0.5
class InvokeAISettings(BaseSettings):
"""
Runtime configuration settings in which default values are
read from an omegaconf .yaml file.
"""
initconf: ClassVar[DictConfig] = None
argparse_groups: ClassVar[Dict] = {}
def parse_args(self, argv: list = sys.argv[1:]):
parser = self.get_parser()
opt = parser.parse_args(argv)
for name in self.__fields__:
if name not in self._excluded():
value = getattr(opt, name)
if isinstance(value, ListConfig):
value = list(value)
elif isinstance(value, DictConfig):
value = dict(value)
setattr(self, name, value)
def to_yaml(self) -> str:
"""
Return a YAML string representing our settings. This can be used
as the contents of `invokeai.yaml` to restore settings later.
"""
cls = self.__class__
type = get_args(get_type_hints(cls)["type"])[0]
field_dict = dict({type: dict()})
for name, field in self.__fields__.items():
if name in cls._excluded_from_yaml():
continue
category = field.field_info.extra.get("category") or "Uncategorized"
value = getattr(self, name)
if category not in field_dict[type]:
field_dict[type][category] = dict()
# keep paths as strings to make it easier to read
field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
conf = OmegaConf.create(field_dict)
return OmegaConf.to_yaml(conf)
@classmethod
def add_parser_arguments(cls, parser):
if "type" in get_type_hints(cls):
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
else:
settings_stanza = "Uncategorized"
env_prefix = cls.Config.env_prefix if hasattr(cls.Config, "env_prefix") else settings_stanza.upper()
initconf = (
cls.initconf.get(settings_stanza)
if cls.initconf and settings_stanza in cls.initconf
else OmegaConf.create()
)
# create an upcase version of the environment in
# order to achieve case-insensitive environment
# variables (the way Windows does)
upcase_environ = dict()
for key, value in os.environ.items():
upcase_environ[key.upper()] = value
fields = cls.__fields__
cls.argparse_groups = {}
for name, field in fields.items():
if name not in cls._excluded():
current_default = field.default
category = field.field_info.extra.get("category", "Uncategorized")
env_name = env_prefix + "_" + name
if category in initconf and name in initconf.get(category):
field.default = initconf.get(category).get(name)
if env_name.upper() in upcase_environ:
field.default = upcase_environ[env_name.upper()]
cls.add_field_argument(parser, name, field)
field.default = current_default
@classmethod
def cmd_name(self, command_field: str = "type") -> str:
hints = get_type_hints(self)
if command_field in hints:
return get_args(hints[command_field])[0]
else:
return "Uncategorized"
@classmethod
def get_parser(cls) -> ArgumentParser:
parser = PagingArgumentParser(
prog=cls.cmd_name(),
description=cls.__doc__,
)
cls.add_parser_arguments(parser)
return parser
@classmethod
def add_subparser(cls, parser: argparse.ArgumentParser):
parser.add_parser(cls.cmd_name(), help=cls.__doc__)
@classmethod
def _excluded(self) -> List[str]:
# internal fields that shouldn't be exposed as command line options
return ["type", "initconf"]
@classmethod
def _excluded_from_yaml(self) -> List[str]:
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
return [
"type",
"initconf",
"version",
"from_file",
"model",
"root",
]
class Config:
env_file_encoding = "utf-8"
arbitrary_types_allowed = True
case_sensitive = True
@classmethod
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
field_type = get_type_hints(cls).get(name)
default = (
default_override
if default_override is not None
else field.default
if field.default_factory is None
else field.default_factory()
)
if category := field.field_info.extra.get("category"):
if category not in cls.argparse_groups:
cls.argparse_groups[category] = command_parser.add_argument_group(category)
argparse_group = cls.argparse_groups[category]
else:
argparse_group = command_parser
if get_origin(field_type) == Literal:
allowed_values = get_args(field.type_)
allowed_types = set()
for val in allowed_values:
allowed_types.add(type(val))
allowed_types_list = list(allowed_types)
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
argparse_group.add_argument(
f"--{name}",
dest=name,
type=field_type,
default=default,
choices=allowed_values,
help=field.field_info.description,
)
elif get_origin(field_type) == list:
argparse_group.add_argument(
f"--{name}",
dest=name,
nargs="*",
type=field.type_,
default=default,
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
help=field.field_info.description,
)
else:
argparse_group.add_argument(
f"--{name}",
dest=name,
type=field.type_,
default=default,
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
help=field.field_info.description,
)
def _find_root() -> Path:
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
if os.environ.get("INVOKEAI_ROOT"):
root = Path(os.environ["INVOKEAI_ROOT"])
elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
root = (venv.parent).resolve()
else:
root = Path("~/invokeai").expanduser().resolve()
return root
class InvokeAIAppConfig(InvokeAISettings):
"""
Generate images using Stable Diffusion. Use "invokeai" to launch
@ -378,6 +197,8 @@ class InvokeAIAppConfig(InvokeAISettings):
# fmt: off
type: Literal["InvokeAI"] = "InvokeAI"
# WEB
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", category='Web Server')
@ -385,20 +206,14 @@ class InvokeAIAppConfig(InvokeAISettings):
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", category='Web Server')
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", category='Web Server')
# FEATURES
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", category='Features')
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', category='Features')
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
max_vram_cache_size : float = Field(default=2.75, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
precision : Literal['auto', 'float16', 'float32', 'autocast'] = Field(default='auto', description='Floating point precision', category='Memory/Performance')
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
# PATHS
root : Path = Field(default=None, description='InvokeAI runtime root directory', category='Paths')
autoimport_dir : Path = Field(default='autoimport', description='Path to a directory of models files to be imported on startup.', category='Paths')
lora_dir : Path = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths')
@ -409,16 +224,41 @@ class InvokeAIAppConfig(InvokeAISettings):
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
db_dir : Path = Field(default='databases', description='Path to InvokeAI databases directory', category='Paths')
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', category='Features')
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
# LOGGING
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
# CACHE
ram : Union[float, Literal["auto"]] = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number or 'auto')", category="Model Cache", )
vram : Union[float, Literal["auto"]] = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number or 'auto')", category="Model Cache", )
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", )
# DEVICE
device : Literal[tuple(["auto", "cpu", "cuda", "cuda:1", "mps"])] = Field(default="auto", description="Generation device", category="Device", )
precision: Literal[tuple(["auto", "float16", "float32", "autocast"])] = Field(default="auto", description="Floating point precision", category="Device", )
# GENERATION
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category="Generation", )
attention_type : Literal[tuple(["auto", "normal", "xformers", "sliced", "torch-sdp"])] = Field(default="auto", description="Attention type", category="Generation", )
attention_slice_size: Literal[tuple(["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8])] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
# See InvokeAIAppConfig subclass below for CACHE and DEVICE categories
# fmt: on
class Config:
@ -541,11 +381,6 @@ class InvokeAIAppConfig(InvokeAISettings):
"""Return true if precision set to float32"""
return self.precision == "float32"
@property
def disable_xformers(self) -> bool:
"""Return true if xformers_enabled is false"""
return not self.xformers_enabled
@property
def try_patchmatch(self) -> bool:
"""Return true if patchmatch true"""
@ -561,6 +396,27 @@ class InvokeAIAppConfig(InvokeAISettings):
"""invisible watermark node is always active and disabled from Web UIe"""
return True
@property
def ram_cache_size(self) -> float:
return self.max_cache_size or self.ram
@property
def vram_cache_size(self) -> float:
return self.max_vram_cache_size or self.vram
@property
def use_cpu(self) -> bool:
return self.always_use_cpu or self.device == "cpu"
@property
def disable_xformers(self) -> bool:
"""
Return true if enable_xformers is false (reversed logic)
and attention type is not set to xformers.
"""
disabled_in_config = not self.xformers_enabled
return disabled_in_config and self.attention_type != "xformers"
@staticmethod
def find_root() -> Path:
"""
@ -570,19 +426,19 @@ class InvokeAIAppConfig(InvokeAISettings):
return _find_root()
class PagingArgumentParser(argparse.ArgumentParser):
"""
A custom ArgumentParser that uses pydoc to page its output.
It also supports reading defaults from an init file.
"""
def print_help(self, file=None):
text = self.format_help()
pydoc.pager(text)
def get_invokeai_config(**kwargs) -> InvokeAIAppConfig:
"""
Legacy function which returns InvokeAIAppConfig.get_config()
"""
return InvokeAIAppConfig.get_config(**kwargs)
def _find_root() -> Path:
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
if os.environ.get("INVOKEAI_ROOT"):
root = Path(os.environ["INVOKEAI_ROOT"])
elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
root = (venv.parent).resolve()
else:
root = Path("~/invokeai").expanduser().resolve()
return root

View File

@ -17,9 +17,9 @@ def create_text_to_image() -> LibraryGraph:
description="Converts text to an image",
graph=Graph(
nodes={
"width": IntegerInvocation(id="width", a=512),
"height": IntegerInvocation(id="height", a=512),
"seed": IntegerInvocation(id="seed", a=-1),
"width": IntegerInvocation(id="width", value=512),
"height": IntegerInvocation(id="height", value=512),
"seed": IntegerInvocation(id="seed", value=-1),
"3": NoiseInvocation(id="3"),
"4": CompelInvocation(id="4"),
"5": CompelInvocation(id="5"),
@ -29,15 +29,15 @@ def create_text_to_image() -> LibraryGraph:
},
edges=[
Edge(
source=EdgeConnection(node_id="width", field="a"),
source=EdgeConnection(node_id="width", field="value"),
destination=EdgeConnection(node_id="3", field="width"),
),
Edge(
source=EdgeConnection(node_id="height", field="a"),
source=EdgeConnection(node_id="height", field="value"),
destination=EdgeConnection(node_id="3", field="height"),
),
Edge(
source=EdgeConnection(node_id="seed", field="a"),
source=EdgeConnection(node_id="seed", field="value"),
destination=EdgeConnection(node_id="3", field="seed"),
),
Edge(
@ -65,9 +65,9 @@ def create_text_to_image() -> LibraryGraph:
exposed_inputs=[
ExposedNodeInput(node_path="4", field="prompt", alias="positive_prompt"),
ExposedNodeInput(node_path="5", field="prompt", alias="negative_prompt"),
ExposedNodeInput(node_path="width", field="a", alias="width"),
ExposedNodeInput(node_path="height", field="a", alias="height"),
ExposedNodeInput(node_path="seed", field="a", alias="seed"),
ExposedNodeInput(node_path="width", field="value", alias="width"),
ExposedNodeInput(node_path="height", field="value", alias="height"),
ExposedNodeInput(node_path="seed", field="value", alias="seed"),
],
exposed_outputs=[ExposedNodeOutput(node_path="8", field="image", alias="image")],
)

View File

@ -49,9 +49,36 @@ from invokeai.backend.model_management.model_cache import CacheStats
GIG = 1073741824
@dataclass
class NodeStats:
"""Class for tracking execution stats of an invocation node"""
calls: int = 0
time_used: float = 0.0 # seconds
max_vram: float = 0.0 # GB
cache_hits: int = 0
cache_misses: int = 0
cache_high_watermark: int = 0
@dataclass
class NodeLog:
"""Class for tracking node usage"""
# {node_type => NodeStats}
nodes: Dict[str, NodeStats] = field(default_factory=dict)
class InvocationStatsServiceBase(ABC):
"Abstract base class for recording node memory/time performance statistics"
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
# {graph_id => NodeLog}
_stats: Dict[str, NodeLog]
_cache_stats: Dict[str, CacheStats]
ram_used: float
ram_changed: float
@abstractmethod
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
"""
@ -94,8 +121,6 @@ class InvocationStatsServiceBase(ABC):
invocation_type: str,
time_used: float,
vram_used: float,
ram_used: float,
ram_changed: float,
):
"""
Add timing information on execution of a node. Usually
@ -104,8 +129,6 @@ class InvocationStatsServiceBase(ABC):
:param invocation_type: String literal type of the node
:param time_used: Time used by node's exection (sec)
:param vram_used: Maximum VRAM used during exection (GB)
:param ram_used: Current RAM available (GB)
:param ram_changed: Change in RAM usage over course of the run (GB)
"""
pass
@ -116,25 +139,19 @@ class InvocationStatsServiceBase(ABC):
"""
pass
@abstractmethod
def update_mem_stats(
self,
ram_used: float,
ram_changed: float,
):
"""
Update the collector with RAM memory usage info.
@dataclass
class NodeStats:
"""Class for tracking execution stats of an invocation node"""
calls: int = 0
time_used: float = 0.0 # seconds
max_vram: float = 0.0 # GB
cache_hits: int = 0
cache_misses: int = 0
cache_high_watermark: int = 0
@dataclass
class NodeLog:
"""Class for tracking node usage"""
# {node_type => NodeStats}
nodes: Dict[str, NodeStats] = field(default_factory=dict)
:param ram_used: How much RAM is currently in use.
:param ram_changed: How much RAM changed since last generation.
"""
pass
class InvocationStatsService(InvocationStatsServiceBase):
@ -152,12 +169,12 @@ class InvocationStatsService(InvocationStatsServiceBase):
class StatsContext:
"""Context manager for collecting statistics."""
invocation: BaseInvocation = None
collector: "InvocationStatsServiceBase" = None
graph_id: str = None
start_time: int = 0
ram_used: int = 0
model_manager: ModelManagerService = None
invocation: BaseInvocation
collector: "InvocationStatsServiceBase"
graph_id: str
start_time: float
ram_used: int
model_manager: ModelManagerService
def __init__(
self,
@ -170,7 +187,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
self.invocation = invocation
self.collector = collector
self.graph_id = graph_id
self.start_time = 0
self.start_time = 0.0
self.ram_used = 0
self.model_manager = model_manager
@ -191,7 +208,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
)
self.collector.update_invocation_stats(
graph_id=self.graph_id,
invocation_type=self.invocation.type,
invocation_type=self.invocation.type, # type: ignore - `type` is not on the `BaseInvocation` model, but *is* on all invocations
time_used=time.time() - self.start_time,
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
)
@ -202,11 +219,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
graph_execution_state_id: str,
model_manager: ModelManagerService,
) -> StatsContext:
"""
Return a context object that will capture the statistics.
:param invocation: BaseInvocation object from the current graph.
:param graph_execution_state: GraphExecutionState object from the current session.
"""
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
self._stats[graph_execution_state_id] = NodeLog()
self._cache_stats[graph_execution_state_id] = CacheStats()
@ -217,7 +229,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
self._stats = {}
def reset_stats(self, graph_execution_id: str):
"""Zero the statistics for the indicated graph."""
try:
self._stats.pop(graph_execution_id)
except KeyError:
@ -228,12 +239,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
ram_used: float,
ram_changed: float,
):
"""
Update the collector with RAM memory usage info.
:param ram_used: How much RAM is currently in use.
:param ram_changed: How much RAM changed since last generation.
"""
self.ram_used = ram_used
self.ram_changed = ram_changed
@ -244,16 +249,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
time_used: float,
vram_used: float,
):
"""
Add timing information on execution of a node. Usually
used internally.
:param graph_id: ID of the graph that is currently executing
:param invocation_type: String literal type of the node
:param time_used: Time used by node's exection (sec)
:param vram_used: Maximum VRAM used during exection (GB)
:param ram_used: Current RAM available (GB)
:param ram_changed: Change in RAM usage over course of the run (GB)
"""
if not self._stats[graph_id].nodes.get(invocation_type):
self._stats[graph_id].nodes[invocation_type] = NodeStats()
stats = self._stats[graph_id].nodes[invocation_type]
@ -262,14 +257,15 @@ class InvocationStatsService(InvocationStatsServiceBase):
stats.max_vram = max(stats.max_vram, vram_used)
def log_stats(self):
"""
Send the statistics to the system logger at the info level.
Stats will only be printed when the execution of the graph
is complete.
"""
completed = set()
errored = set()
for graph_id, node_log in self._stats.items():
current_graph_state = self.graph_execution_manager.get(graph_id)
try:
current_graph_state = self.graph_execution_manager.get(graph_id)
except Exception:
errored.add(graph_id)
continue
if not current_graph_state.is_complete():
continue
@ -302,3 +298,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
for graph_id in completed:
del self._stats[graph_id]
del self._cache_stats[graph_id]
for graph_id in errored:
del self._stats[graph_id]
del self._cache_stats[graph_id]

View File

@ -330,8 +330,8 @@ class ModelManagerService(ModelManagerServiceBase):
# configuration value. If present, then the
# cache size is set to 2.5 GB times
# the number of max_loaded_models. Otherwise
# use new `max_cache_size` config setting
max_cache_size = config.max_cache_size if hasattr(config, "max_cache_size") else config.max_loaded_models * 2.5
# use new `ram_cache_size` config setting
max_cache_size = config.ram_cache_size
logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")

View File

@ -0,0 +1,56 @@
import gc
from typing import Any
import numpy as np
import torch
from PIL import Image
from invokeai.app.services.config import get_invokeai_config
from invokeai.backend.util.devices import choose_torch_device
def norm_img(np_img):
if len(np_img.shape) == 2:
np_img = np_img[:, :, np.newaxis]
np_img = np.transpose(np_img, (2, 0, 1))
np_img = np_img.astype("float32") / 255
return np_img
def load_jit_model(url_or_path, device):
model_path = url_or_path
print(f"Loading model from: {model_path}")
model = torch.jit.load(model_path, map_location="cpu").to(device)
model.eval()
return model
class LaMA:
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
device = choose_torch_device()
model_location = get_invokeai_config().models_path / "core/misc/lama/lama.pt"
model = load_jit_model(model_location, device)
image = np.asarray(input_image.convert("RGB"))
image = norm_img(image)
mask = input_image.split()[-1]
mask = np.asarray(mask)
mask = np.invert(mask)
mask = norm_img(mask)
mask = (mask > 0) * 1
image = torch.from_numpy(image).unsqueeze(0).to(device)
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
with torch.inference_mode():
infilled_image = model(image, mask)
infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy()
infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8")
infilled_image = Image.fromarray(infilled_image)
del model
gc.collect()
return infilled_image

View File

@ -21,6 +21,7 @@ from argparse import Namespace
from enum import Enum
from pathlib import Path
from shutil import get_terminal_size
from typing import get_type_hints, get_args, Any
from urllib import request
import npyscreen
@ -50,6 +51,7 @@ from invokeai.frontend.install.model_install import addModelsForm, process_and_e
# TO DO - Move all the frontend code into invokeai.frontend.install
from invokeai.frontend.install.widgets import (
SingleSelectColumns,
MultiSelectColumns,
CenteredButtonPress,
FileBox,
set_min_terminal_size,
@ -71,6 +73,10 @@ warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error()
def get_literal_fields(field) -> list[Any]:
return get_args(get_type_hints(InvokeAIAppConfig).get(field))
# --------------------------globals-----------------------
config = InvokeAIAppConfig.get_config()
@ -80,7 +86,11 @@ Model_dir = "models"
Default_config_file = config.model_conf_path
SD_Configs = config.legacy_conf_path
PRECISION_CHOICES = ["auto", "float16", "float32"]
PRECISION_CHOICES = get_literal_fields("precision")
DEVICE_CHOICES = get_literal_fields("device")
ATTENTION_CHOICES = get_literal_fields("attention_type")
ATTENTION_SLICE_CHOICES = get_literal_fields("attention_slice_size")
GENERATION_OPT_CHOICES = ["sequential_guidance", "force_tiled_decode", "lazy_offload"]
GB = 1073741824 # GB in bytes
HAS_CUDA = torch.cuda.is_available()
_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0, 0)
@ -311,6 +321,7 @@ class editOptsForm(CyclingForm, npyscreen.FormMultiPage):
Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.
Use cursor arrows to make a checkbox selection, and space to toggle.
"""
self.nextrely -= 1
for i in textwrap.wrap(label, width=window_width - 6):
self.add_widget_intelligent(
npyscreen.FixedText,
@ -337,76 +348,129 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
use_two_lines=False,
scroll_exit=True,
)
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="GPU Management",
begin_entry_at=0,
editable=False,
color="CONTROL",
scroll_exit=True,
)
self.nextrely -= 1
self.free_gpu_mem = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Free GPU memory after each generation",
value=old_opts.free_gpu_mem,
max_width=45,
relx=5,
scroll_exit=True,
)
self.nextrely -= 1
self.xformers_enabled = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Enable xformers support",
value=old_opts.xformers_enabled,
max_width=30,
relx=50,
scroll_exit=True,
)
self.nextrely -= 1
self.always_use_cpu = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Force CPU to be used on GPU systems",
value=old_opts.always_use_cpu,
relx=80,
scroll_exit=True,
)
# old settings for defaults
precision = old_opts.precision or ("float32" if program_opts.full_precision else "auto")
device = old_opts.device
attention_type = old_opts.attention_type
attention_slice_size = old_opts.attention_slice_size
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="Floating Point Precision",
name="Image Generation Options:",
editable=False,
color="CONTROL",
scroll_exit=True,
)
self.nextrely -= 2
self.generation_options = self.add_widget_intelligent(
MultiSelectColumns,
columns=3,
values=GENERATION_OPT_CHOICES,
value=[GENERATION_OPT_CHOICES.index(x) for x in GENERATION_OPT_CHOICES if getattr(old_opts, x)],
relx=30,
max_height=2,
max_width=80,
scroll_exit=True,
)
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="Floating Point Precision:",
begin_entry_at=0,
editable=False,
color="CONTROL",
scroll_exit=True,
)
self.nextrely -= 1
self.nextrely -= 2
self.precision = self.add_widget_intelligent(
SingleSelectColumns,
columns=3,
columns=len(PRECISION_CHOICES),
name="Precision",
values=PRECISION_CHOICES,
value=PRECISION_CHOICES.index(precision),
begin_entry_at=3,
max_height=2,
relx=30,
max_width=56,
scroll_exit=True,
)
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="Generation Device:",
begin_entry_at=0,
editable=False,
color="CONTROL",
scroll_exit=True,
)
self.nextrely -= 2
self.device = self.add_widget_intelligent(
SingleSelectColumns,
columns=len(DEVICE_CHOICES),
values=DEVICE_CHOICES,
value=DEVICE_CHOICES.index(device),
begin_entry_at=3,
relx=30,
max_height=2,
max_width=60,
scroll_exit=True,
)
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="Attention Type:",
begin_entry_at=0,
editable=False,
color="CONTROL",
scroll_exit=True,
)
self.nextrely -= 2
self.attention_type = self.add_widget_intelligent(
SingleSelectColumns,
columns=len(ATTENTION_CHOICES),
values=ATTENTION_CHOICES,
value=ATTENTION_CHOICES.index(attention_type),
begin_entry_at=3,
max_height=2,
relx=30,
max_width=80,
scroll_exit=True,
)
self.nextrely += 1
self.attention_type.on_changed = self.show_hide_slice_sizes
self.attention_slice_label = self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="Attention Slice Size:",
relx=5,
editable=False,
hidden=attention_type != "sliced",
color="CONTROL",
scroll_exit=True,
)
self.nextrely -= 2
self.attention_slice_size = self.add_widget_intelligent(
SingleSelectColumns,
columns=len(ATTENTION_SLICE_CHOICES),
values=ATTENTION_SLICE_CHOICES,
value=ATTENTION_SLICE_CHOICES.index(attention_slice_size),
relx=30,
hidden=attention_type != "sliced",
max_height=2,
max_width=110,
scroll_exit=True,
)
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="RAM cache size (GB). Make this at least large enough to hold a single full model.",
name="Model RAM cache size (GB). Make this at least large enough to hold a single full model.",
begin_entry_at=0,
editable=False,
color="CONTROL",
scroll_exit=True,
)
self.nextrely -= 1
self.max_cache_size = self.add_widget_intelligent(
self.ram = self.add_widget_intelligent(
npyscreen.Slider,
value=clip(old_opts.max_cache_size, range=(3.0, MAX_RAM), step=0.5),
value=clip(old_opts.ram_cache_size, range=(3.0, MAX_RAM), step=0.5),
out_of=round(MAX_RAM),
lowest=0.0,
step=0.5,
@ -417,16 +481,16 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="VRAM cache size (GB). Reserving a small amount of VRAM will modestly speed up the start of image generation.",
name="Model VRAM cache size (GB). Reserving a small amount of VRAM will modestly speed up the start of image generation.",
begin_entry_at=0,
editable=False,
color="CONTROL",
scroll_exit=True,
)
self.nextrely -= 1
self.max_vram_cache_size = self.add_widget_intelligent(
self.vram = self.add_widget_intelligent(
npyscreen.Slider,
value=clip(old_opts.max_vram_cache_size, range=(0, MAX_VRAM), step=0.25),
value=clip(old_opts.vram_cache_size, range=(0, MAX_VRAM), step=0.25),
out_of=round(MAX_VRAM * 2) / 2,
lowest=0.0,
relx=8,
@ -434,7 +498,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
scroll_exit=True,
)
else:
self.max_vram_cache_size = DummyWidgetValue.zero
self.vram_cache_size = DummyWidgetValue.zero
self.nextrely += 1
self.outdir = self.add_widget_intelligent(
FileBox,
@ -490,6 +554,11 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
when_pressed_function=self.on_ok,
)
def show_hide_slice_sizes(self, value):
show = ATTENTION_CHOICES[value[0]] == "sliced"
self.attention_slice_label.hidden = not show
self.attention_slice_size.hidden = not show
def on_ok(self):
options = self.marshall_arguments()
if self.validate_field_values(options):
@ -523,12 +592,9 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
new_opts = Namespace()
for attr in [
"ram",
"vram",
"outdir",
"free_gpu_mem",
"max_cache_size",
"max_vram_cache_size",
"xformers_enabled",
"always_use_cpu",
]:
setattr(new_opts, attr, getattr(self, attr).value)
@ -541,6 +607,12 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
new_opts.hf_token = self.hf_token.value
new_opts.license_acceptance = self.license_acceptance.value
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
new_opts.device = DEVICE_CHOICES[self.device.value[0]]
new_opts.attention_type = ATTENTION_CHOICES[self.attention_type.value[0]]
new_opts.attention_slice_size = ATTENTION_SLICE_CHOICES[self.attention_slice_size.value[0]]
generation_options = [GENERATION_OPT_CHOICES[x] for x in self.generation_options.value]
for v in GENERATION_OPT_CHOICES:
setattr(new_opts, v, v in generation_options)
return new_opts

View File

@ -20,11 +20,36 @@
import re
from contextlib import nullcontext
from io import BytesIO
from typing import Optional, Union
from pathlib import Path
from typing import Optional, Union
import requests
import torch
from diffusers.models import (
AutoencoderKL,
ControlNetModel,
PriorTransformer,
UNet2DConditionModel,
)
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
from diffusers.schedulers import (
DDIMScheduler,
DDPMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
UnCLIPScheduler,
)
from diffusers.utils import is_accelerate_available, is_omegaconf_available
from diffusers.utils.import_utils import BACKENDS_MAPPING
from picklescan.scanner import scan_file_path
from transformers import (
AutoFeatureExtractor,
BertTokenizerFast,
@ -37,35 +62,8 @@ from transformers import (
CLIPVisionModelWithProjection,
)
from diffusers.models import (
AutoencoderKL,
ControlNetModel,
PriorTransformer,
UNet2DConditionModel,
)
from diffusers.schedulers import (
DDIMScheduler,
DDPMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
UnCLIPScheduler,
)
from diffusers.utils import is_accelerate_available, is_omegaconf_available, is_safetensors_available
from diffusers.utils.import_utils import BACKENDS_MAPPING
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.app.services.config import InvokeAIAppConfig
from picklescan.scanner import scan_file_path
from invokeai.backend.util.logging import InvokeAILogger
from .models import BaseModelType, ModelVariantType
try:
@ -1221,9 +1219,6 @@ def download_from_original_stable_diffusion_ckpt(
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
if from_safetensors:
if not is_safetensors_available():
raise ValueError(BACKENDS_MAPPING["safetensors"][1])
from safetensors.torch import load_file as safe_load
checkpoint = safe_load(checkpoint_path, device="cpu")
@ -1662,9 +1657,6 @@ def download_controlnet_from_original_ckpt(
from omegaconf import OmegaConf
if from_safetensors:
if not is_safetensors_available():
raise ValueError(BACKENDS_MAPPING["safetensors"][1])
from safetensors import safe_open
checkpoint = {}
@ -1741,7 +1733,7 @@ def convert_ckpt_to_diffusers(
pipe.save_pretrained(
dump_path,
safe_serialization=use_safetensors and is_safetensors_available(),
safe_serialization=use_safetensors,
)
@ -1757,7 +1749,4 @@ def convert_controlnet_to_diffusers(
"""
pipe = download_controlnet_from_original_ckpt(checkpoint_path, **kwargs)
pipe.save_pretrained(
dump_path,
safe_serialization=is_safetensors_available(),
)
pipe.save_pretrained(dump_path, safe_serialization=True)

View File

@ -341,7 +341,8 @@ class ModelManager(object):
self.logger = logger
self.cache = ModelCache(
max_cache_size=max_cache_size,
max_vram_cache_size=self.app_config.max_vram_cache_size,
max_vram_cache_size=self.app_config.vram_cache_size,
lazy_offloading=self.app_config.lazy_offload,
execution_device=device_type,
precision=precision,
sequential_offload=sequential_offload,

View File

@ -5,7 +5,6 @@ from typing import Optional
import safetensors
import torch
from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf
from invokeai.app.services.config import InvokeAIAppConfig
@ -175,5 +174,5 @@ def _convert_vae_ckpt_and_cache(
vae_config=config,
image_size=image_size,
)
vae_model.save_pretrained(output_path, safe_serialization=is_safetensors_available())
vae_model.save_pretrained(output_path, safe_serialization=True)
return output_path

View File

@ -33,7 +33,7 @@ from .diffusion import (
PostprocessingSettings,
BasicConditioningInfo,
)
from ..util import normalize_device
from ..util import normalize_device, auto_detect_slice_size
@dataclass
@ -291,6 +291,24 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if xformers is available, use it, otherwise use sliced attention.
"""
config = InvokeAIAppConfig.get_config()
if config.attention_type == "xformers":
self.enable_xformers_memory_efficient_attention()
return
elif config.attention_type == "sliced":
slice_size = config.attention_slice_size
if slice_size == "auto":
slice_size = auto_detect_slice_size(latents)
elif slice_size == "balanced":
slice_size = "auto"
self.enable_attention_slicing(slice_size=slice_size)
return
elif config.attention_type == "normal":
self.disable_attention_slicing()
return
elif config.attention_type == "torch-sdp":
raise Exception("torch-sdp attention slicing not yet implemented")
# the remainder if this code is called when attention_type=='auto'
if self.unet.device.type == "cuda":
if is_xformers_available() and not config.disable_xformers:
self.enable_xformers_memory_efficient_attention()

View File

@ -11,4 +11,11 @@ from .devices import ( # noqa: F401
torch_dtype,
)
from .log import write_log # noqa: F401
from .util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name, Chdir # noqa: F401
from .util import ( # noqa: F401
ask_user,
download_with_resume,
instantiate_from_config,
url_attachment_name,
Chdir,
)
from .attention import auto_detect_slice_size # noqa: F401

View File

@ -0,0 +1,32 @@
# Copyright (c) 2023 Lincoln Stein and the InvokeAI Team
"""
Utility routine used for autodetection of optimal slice size
for attention mechanism.
"""
import torch
import psutil
def auto_detect_slice_size(latents: torch.Tensor) -> str:
bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
max_size_required_for_baddbmm = (
16
* latents.size(dim=2)
* latents.size(dim=3)
* latents.size(dim=2)
* latents.size(dim=3)
* bytes_per_element_needed_for_baddbmm_duplication
)
if latents.device.type in {"cpu", "mps"}:
mem_free = psutil.virtual_memory().free
elif latents.device.type == "cuda":
mem_free, _ = torch.cuda.mem_get_info(latents.device)
else:
raise ValueError(f"unrecognized device {latents.device}")
if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0):
return "max"
elif torch.backends.mps.is_available():
return "max"
else:
return "balanced"

View File

@ -17,13 +17,17 @@ config = InvokeAIAppConfig.get_config()
def choose_torch_device() -> torch.device:
"""Convenience routine for guessing which GPU device to run model on"""
if config.always_use_cpu:
if config.use_cpu: # legacy setting - force CPU
return CPU_DEVICE
if torch.cuda.is_available():
return torch.device("cuda")
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return torch.device("mps")
return CPU_DEVICE
elif config.device == "auto":
if torch.cuda.is_available():
return torch.device("cuda")
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return torch.device("mps")
else:
return CPU_DEVICE
else:
return torch.device(config.device)
def choose_precision(device: torch.device) -> str:

View File

@ -17,8 +17,8 @@ from shutil import get_terminal_size
from curses import BUTTON2_CLICKED, BUTTON3_CLICKED
# minimum size for UIs
MIN_COLS = 130
MIN_LINES = 38
MIN_COLS = 150
MIN_LINES = 40
class WindowTooSmallException(Exception):
@ -277,6 +277,9 @@ class SingleSelectColumns(SelectColumnBase, SingleSelectWithChanged):
def h_cursor_line_right(self, ch):
self.h_exit_down("bye bye")
def h_cursor_line_left(self, ch):
self.h_exit_up("bye bye")
class TextBoxInner(npyscreen.MultiLineEdit):
def __init__(self, *args, **kwargs):
@ -324,55 +327,6 @@ class TextBoxInner(npyscreen.MultiLineEdit):
if bstate & (BUTTON2_CLICKED | BUTTON3_CLICKED):
self.h_paste()
# def update(self, clear=True):
# if clear:
# self.clear()
# HEIGHT = self.height
# WIDTH = self.width
# # draw box.
# self.parent.curses_pad.hline(self.rely, self.relx, curses.ACS_HLINE, WIDTH)
# self.parent.curses_pad.hline(
# self.rely + HEIGHT, self.relx, curses.ACS_HLINE, WIDTH
# )
# self.parent.curses_pad.vline(
# self.rely, self.relx, curses.ACS_VLINE, self.height
# )
# self.parent.curses_pad.vline(
# self.rely, self.relx + WIDTH, curses.ACS_VLINE, HEIGHT
# )
# # draw corners
# self.parent.curses_pad.addch(
# self.rely,
# self.relx,
# curses.ACS_ULCORNER,
# )
# self.parent.curses_pad.addch(
# self.rely,
# self.relx + WIDTH,
# curses.ACS_URCORNER,
# )
# self.parent.curses_pad.addch(
# self.rely + HEIGHT,
# self.relx,
# curses.ACS_LLCORNER,
# )
# self.parent.curses_pad.addch(
# self.rely + HEIGHT,
# self.relx + WIDTH,
# curses.ACS_LRCORNER,
# )
# # fool our superclass into thinking drawing area is smaller - this is really hacky but it seems to work
# (relx, rely, height, width) = (self.relx, self.rely, self.height, self.width)
# self.relx += 1
# self.rely += 1
# self.height -= 1
# self.width -= 1
# super().update(clear=False)
# (self.relx, self.rely, self.height, self.width) = (relx, rely, height, width)
class TextBox(npyscreen.BoxTitle):
_contained_widget = TextBoxInner

View File

@ -9,8 +9,8 @@ module.exports = {
'plugin:@typescript-eslint/recommended',
'plugin:react/recommended',
'plugin:react-hooks/recommended',
'plugin:prettier/recommended',
'plugin:react/jsx-runtime',
'prettier',
],
parser: '@typescript-eslint/parser',
parserOptions: {
@ -23,6 +23,11 @@ module.exports = {
plugins: ['react', '@typescript-eslint', 'eslint-plugin-react-hooks'],
root: true,
rules: {
curly: 'error',
'react/jsx-curly-brace-presence': [
'error',
{ props: 'never', children: 'never' },
],
'react-hooks/exhaustive-deps': 'error',
'no-var': 'error',
'brace-style': 'error',
@ -34,7 +39,6 @@ module.exports = {
'warn',
{ varsIgnorePattern: '^_', argsIgnorePattern: '^_' },
],
'prettier/prettier': ['error', { endOfLine: 'auto' }],
'@typescript-eslint/ban-ts-comment': 'warn',
'@typescript-eslint/no-explicit-any': 'warn',
'@typescript-eslint/no-empty-interface': [

View File

@ -29,12 +29,13 @@
"lint:eslint": "eslint --max-warnings=0 .",
"lint:prettier": "prettier --check .",
"lint:tsc": "tsc --noEmit",
"lint": "yarn run lint:eslint && yarn run lint:prettier && yarn run lint:tsc && yarn run lint:madge",
"lint": "concurrently -g -n eslint,prettier,tsc,madge -c cyan,green,magenta,yellow \"yarn run lint:eslint\" \"yarn run lint:prettier\" \"yarn run lint:tsc\" \"yarn run lint:madge\"",
"fix": "eslint --fix . && prettier --loglevel warn --write . && tsc --noEmit",
"lint-staged": "lint-staged",
"postinstall": "patch-package && yarn run theme",
"theme": "chakra-cli tokens src/theme/theme.ts",
"theme:watch": "chakra-cli tokens src/theme/theme.ts --watch"
"theme:watch": "chakra-cli tokens src/theme/theme.ts --watch",
"up": "yarn upgrade-interactive --latest"
},
"madge": {
"detectiveOptions": {
@ -54,7 +55,7 @@
},
"dependencies": {
"@chakra-ui/anatomy": "^2.2.0",
"@chakra-ui/icons": "^2.0.19",
"@chakra-ui/icons": "^2.1.0",
"@chakra-ui/react": "^2.8.0",
"@chakra-ui/styled-system": "^2.9.1",
"@chakra-ui/theme-tools": "^2.1.0",
@ -65,55 +66,55 @@
"@emotion/react": "^11.11.1",
"@emotion/styled": "^11.11.0",
"@floating-ui/react-dom": "^2.0.1",
"@fontsource-variable/inter": "^5.0.3",
"@fontsource/inter": "^5.0.3",
"@mantine/core": "^6.0.14",
"@mantine/form": "^6.0.15",
"@mantine/hooks": "^6.0.14",
"@fontsource-variable/inter": "^5.0.8",
"@fontsource/inter": "^5.0.8",
"@mantine/core": "^6.0.19",
"@mantine/form": "^6.0.19",
"@mantine/hooks": "^6.0.19",
"@nanostores/react": "^0.7.1",
"@reduxjs/toolkit": "^1.9.5",
"@roarr/browser-log-writer": "^1.1.5",
"chakra-ui-contextmenu": "^1.0.5",
"dateformat": "^5.0.3",
"downshift": "^7.6.0",
"formik": "^2.4.2",
"framer-motion": "^10.12.17",
"formik": "^2.4.3",
"framer-motion": "^10.16.1",
"fuse.js": "^6.6.2",
"i18next": "^23.2.3",
"i18next": "^23.4.4",
"i18next-browser-languagedetector": "^7.0.2",
"i18next-http-backend": "^2.2.1",
"konva": "^9.2.0",
"lodash-es": "^4.17.21",
"nanostores": "^0.9.2",
"openapi-fetch": "^0.6.1",
"new-github-issue-url": "^1.0.0",
"openapi-fetch": "^0.7.4",
"overlayscrollbars": "^2.2.0",
"overlayscrollbars-react": "^0.5.0",
"patch-package": "^7.0.0",
"patch-package": "^8.0.0",
"query-string": "^8.1.0",
"re-resizable": "^6.9.9",
"react": "^18.2.0",
"react-colorful": "^5.6.1",
"react-dom": "^18.2.0",
"react-dropzone": "^14.2.3",
"react-hotkeys-hook": "4.4.0",
"react-i18next": "^13.0.1",
"react-error-boundary": "^4.0.11",
"react-hotkeys-hook": "4.4.1",
"react-i18next": "^13.1.2",
"react-icons": "^4.10.1",
"react-konva": "^18.2.10",
"react-redux": "^8.1.1",
"react-resizable-panels": "^0.0.52",
"react-redux": "^8.1.2",
"react-resizable-panels": "^0.0.55",
"react-use": "^17.4.0",
"react-virtuoso": "^4.3.11",
"react-virtuoso": "^4.5.0",
"react-zoom-pan-pinch": "^3.0.8",
"reactflow": "^11.7.4",
"reactflow": "^11.8.3",
"redux-dynamic-middlewares": "^2.2.0",
"redux-remember": "^3.3.1",
"roarr": "^7.15.0",
"serialize-error": "^11.0.0",
"socket.io-client": "^4.7.0",
"redux-remember": "^4.0.1",
"roarr": "^7.15.1",
"serialize-error": "^11.0.1",
"socket.io-client": "^4.7.2",
"use-debounce": "^9.0.4",
"use-image": "^1.1.1",
"uuid": "^9.0.0",
"zod": "^3.21.4"
"zod": "^3.22.2",
"zod-validation-error": "^1.5.0"
},
"peerDependencies": {
"@chakra-ui/cli": "^2.4.0",
@ -126,38 +127,36 @@
"@chakra-ui/cli": "^2.4.1",
"@types/dateformat": "^5.0.0",
"@types/lodash-es": "^4.14.194",
"@types/node": "^20.3.1",
"@types/react": "^18.2.14",
"@types/node": "^20.5.1",
"@types/react": "^18.2.20",
"@types/react-dom": "^18.2.6",
"@types/react-redux": "^7.1.25",
"@types/react-transition-group": "^4.4.6",
"@types/uuid": "^9.0.2",
"@typescript-eslint/eslint-plugin": "^5.60.0",
"@typescript-eslint/parser": "^5.60.0",
"@typescript-eslint/eslint-plugin": "^6.4.1",
"@typescript-eslint/parser": "^6.4.1",
"@vitejs/plugin-react-swc": "^3.3.2",
"axios": "^1.4.0",
"babel-plugin-transform-imports": "^2.0.0",
"concurrently": "^8.2.0",
"eslint": "^8.43.0",
"eslint-config-prettier": "^8.8.0",
"eslint-plugin-prettier": "^4.2.1",
"eslint-plugin-react": "^7.32.2",
"eslint": "^8.47.0",
"eslint-config-prettier": "^9.0.0",
"eslint-plugin-prettier": "^5.0.0",
"eslint-plugin-react": "^7.33.2",
"eslint-plugin-react-hooks": "^4.6.0",
"form-data": "^4.0.0",
"husky": "^8.0.3",
"lint-staged": "^13.2.2",
"lint-staged": "^14.0.1",
"madge": "^6.1.0",
"openapi-types": "^12.1.3",
"openapi-typescript": "^6.2.8",
"openapi-typescript-codegen": "^0.24.0",
"openapi-typescript": "^6.5.2",
"postinstall-postinstall": "^2.1.0",
"prettier": "^2.8.8",
"prettier": "^3.0.2",
"rollup-plugin-visualizer": "^5.9.2",
"terser": "^5.18.1",
"ts-toolbelt": "^9.6.0",
"vite": "^4.3.9",
"vite-plugin-css-injected-by-js": "^3.1.1",
"vite-plugin-dts": "^2.3.0",
"vite": "^4.4.9",
"vite-plugin-css-injected-by-js": "^3.3.0",
"vite-plugin-dts": "^3.5.2",
"vite-plugin-eslint": "^1.8.1",
"vite-tsconfig-paths": "^4.2.0",
"yarn": "^1.22.19"

View File

@ -19,7 +19,7 @@
"toggleAutoscroll": "Toggle autoscroll",
"toggleLogViewer": "Toggle Log Viewer",
"showGallery": "Show Gallery",
"showOptionsPanel": "Show Options Panel",
"showOptionsPanel": "Show Side Panel",
"menu": "Menu"
},
"common": {
@ -52,7 +52,7 @@
"img2img": "Image To Image",
"unifiedCanvas": "Unified Canvas",
"linear": "Linear",
"nodes": "Node Editor",
"nodes": "Workflow Editor",
"batch": "Batch Manager",
"modelManager": "Model Manager",
"postprocessing": "Post Processing",
@ -95,7 +95,6 @@
"statusModelConverted": "Model Converted",
"statusMergingModels": "Merging Models",
"statusMergedModels": "Models Merged",
"pinOptionsPanel": "Pin Options Panel",
"loading": "Loading",
"loadingInvokeAI": "Loading Invoke AI",
"random": "Random",
@ -116,7 +115,6 @@
"maintainAspectRatio": "Maintain Aspect Ratio",
"autoSwitchNewImages": "Auto-Switch to New Images",
"singleColumnLayout": "Single Column Layout",
"pinGallery": "Pin Gallery",
"allImagesLoaded": "All Images Loaded",
"loadMore": "Load More",
"noImagesInGallery": "No Images to Display",
@ -133,6 +131,7 @@
"generalHotkeys": "General Hotkeys",
"galleryHotkeys": "Gallery Hotkeys",
"unifiedCanvasHotkeys": "Unified Canvas Hotkeys",
"nodesHotkeys": "Nodes Hotkeys",
"invoke": {
"title": "Invoke",
"desc": "Generate an image"
@ -332,6 +331,10 @@
"acceptStagingImage": {
"title": "Accept Staging Image",
"desc": "Accept Current Staging Area Image"
},
"addNodes": {
"title": "Add Nodes",
"desc": "Opens the add node menu"
}
},
"modelManager": {
@ -506,12 +509,9 @@
"maskAdjustmentsHeader": "Mask Adjustments",
"maskBlur": "Mask Blur",
"maskBlurMethod": "Mask Blur Method",
"seamPaintingHeader": "Seam Painting",
"seamSize": "Seam Size",
"seamBlur": "Seam Blur",
"seamSteps": "Seam Steps",
"seamStrength": "Seam Strength",
"seamThreshold": "Seam Threshold",
"coherencePassHeader": "Coherence Pass",
"coherenceSteps": "Coherence Pass Steps",
"coherenceStrength": "Coherence Pass Strength",
"seamLowThreshold": "Low",
"seamHighThreshold": "High",
"scaleBeforeProcessing": "Scale Before Processing",
@ -572,7 +572,7 @@
"resetWebUI": "Reset Web UI",
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
"resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.",
"resetComplete": "Web UI has been reset. Refresh the page to reload.",
"resetComplete": "Web UI has been reset.",
"consoleLogLevel": "Log Level",
"shouldLogToConsole": "Console Logging",
"developer": "Developer",
@ -715,11 +715,12 @@
"swapSizes": "Swap Sizes"
},
"nodes": {
"reloadSchema": "Reload Schema",
"saveGraph": "Save Graph",
"loadGraph": "Load Graph (saved from Node Editor) (Do not copy-paste metadata)",
"clearGraph": "Clear Graph",
"clearGraphDesc": "Are you sure you want to clear all nodes?",
"reloadNodeTemplates": "Reload Node Templates",
"saveWorkflow": "Save Workflow",
"loadWorkflow": "Load Workflow",
"resetWorkflow": "Reset Workflow",
"resetWorkflowDesc": "Are you sure you want to reset this workflow?",
"resetWorkflowDesc2": "Resetting the workflow will clear all nodes, edges and workflow details.",
"zoomInNodes": "Zoom In",
"zoomOutNodes": "Zoom Out",
"fitViewportNodes": "Fit View",

View File

@ -27,22 +27,10 @@ async function main() {
* field accepts connection input. If it does, we can make the field optional.
*/
// Check if we are generating types for an invocation
const isInvocationPath = metadata.path.match(
/^#\/components\/schemas\/\w*Invocation$/
);
const hasInvocationProperties =
schemaObject.properties &&
['id', 'is_intermediate', 'type'].every(
(prop) => prop in schemaObject.properties
);
if (isInvocationPath && hasInvocationProperties) {
if ('class' in schemaObject && schemaObject.class === 'invocation') {
// We only want to make fields optional if they are required
if (!Array.isArray(schemaObject?.required)) {
schemaObject.required = ['id', 'type'];
return;
schemaObject.required = [];
}
schemaObject.required.forEach((prop) => {
@ -61,19 +49,13 @@ async function main() {
);
}
});
schemaObject.required = [
...new Set(schemaObject.required.concat(['id', 'type'])),
];
return;
}
// if (
// 'input' in schemaObject &&
// (schemaObject.input === 'any' || schemaObject.input === 'connection')
// ) {
// schemaObject.required = false;
// }
// Check if we are generating types for an invocation output
if ('class' in schemaObject && schemaObject.class === 'output') {
// modify output types
}
},
});
fs.writeFileSync(OUTPUT_FILE, types);

View File

@ -1,4 +1,4 @@
import { Flex, Grid, Portal } from '@chakra-ui/react';
import { Flex, Grid } from '@chakra-ui/react';
import { useLogger } from 'app/logging/useLogger';
import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@ -6,17 +6,15 @@ import { PartialAppConfig } from 'app/types/invokeai';
import ImageUploader from 'common/components/ImageUploader';
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
import GalleryDrawer from 'features/gallery/components/GalleryPanel';
import SiteHeader from 'features/system/components/SiteHeader';
import { configChanged } from 'features/system/store/configSlice';
import { languageSelector } from 'features/system/store/systemSelectors';
import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton';
import FloatingParametersPanelButtons from 'features/ui/components/FloatingParametersPanelButtons';
import InvokeTabs from 'features/ui/components/InvokeTabs';
import ParametersDrawer from 'features/ui/components/ParametersDrawer';
import i18n from 'i18n';
import { size } from 'lodash-es';
import { ReactNode, memo, useEffect } from 'react';
import { ReactNode, memo, useCallback, useEffect } from 'react';
import { ErrorBoundary } from 'react-error-boundary';
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
import GlobalHotkeys from './GlobalHotkeys';
import Toaster from './Toaster';
@ -30,8 +28,13 @@ interface Props {
const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
const language = useAppSelector(languageSelector);
const logger = useLogger();
const logger = useLogger('system');
const dispatch = useAppDispatch();
const handleReset = useCallback(() => {
localStorage.clear();
location.reload();
return false;
}, []);
useEffect(() => {
i18n.changeLanguage(language);
@ -39,7 +42,7 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
useEffect(() => {
if (size(config)) {
logger.info({ namespace: 'App', config }, 'Received config');
logger.info({ config }, 'Received config');
dispatch(configChanged(config));
}
}, [dispatch, config, logger]);
@ -49,7 +52,10 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
}, [dispatch]);
return (
<>
<ErrorBoundary
onReset={handleReset}
FallbackComponent={AppErrorBoundaryFallback}
>
<Grid w="100vw" h="100vh" position="relative" overflow="hidden">
<ImageUploader>
<Grid
@ -73,21 +79,12 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
</Flex>
</Grid>
</ImageUploader>
<GalleryDrawer />
<ParametersDrawer />
<Portal>
<FloatingParametersPanelButtons />
</Portal>
<Portal>
<FloatingGalleryButton />
</Portal>
</Grid>
<DeleteImageModal />
<ChangeBoardModal />
<Toaster />
<GlobalHotkeys />
</>
</ErrorBoundary>
);
};

View File

@ -0,0 +1,97 @@
import { Flex, Heading, Link, Text, useToast } from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton';
import newGithubIssueUrl from 'new-github-issue-url';
import { memo, useCallback, useMemo } from 'react';
import { FaCopy, FaExternalLinkAlt } from 'react-icons/fa';
import { FaArrowRotateLeft } from 'react-icons/fa6';
import { serializeError } from 'serialize-error';
type Props = {
error: Error;
resetErrorBoundary: () => void;
};
const AppErrorBoundaryFallback = ({ error, resetErrorBoundary }: Props) => {
const toast = useToast();
const handleCopy = useCallback(() => {
const text = JSON.stringify(serializeError(error), null, 2);
navigator.clipboard.writeText(`\`\`\`\n${text}\n\`\`\``);
toast({
title: 'Error Copied',
});
}, [error, toast]);
const url = useMemo(
() =>
newGithubIssueUrl({
user: 'invoke-ai',
repo: 'InvokeAI',
template: 'BUG_REPORT.yml',
title: `[bug]: ${error.name}: ${error.message}`,
}),
[error.message, error.name]
);
return (
<Flex
layerStyle="body"
sx={{
w: '100vw',
h: '100vh',
alignItems: 'center',
justifyContent: 'center',
p: 4,
}}
>
<Flex
layerStyle="first"
sx={{
flexDir: 'column',
borderRadius: 'base',
justifyContent: 'center',
gap: 8,
p: 16,
}}
>
<Heading>Something went wrong</Heading>
<Flex
layerStyle="second"
sx={{
px: 8,
py: 4,
borderRadius: 'base',
gap: 4,
justifyContent: 'space-between',
alignItems: 'center',
}}
>
<Text
sx={{
fontWeight: 600,
color: 'error.500',
_dark: { color: 'error.400' },
}}
>
{error.name}: {error.message}
</Text>
</Flex>
<Flex sx={{ gap: 4 }}>
<IAIButton
leftIcon={<FaArrowRotateLeft />}
onClick={resetErrorBoundary}
>
Reset UI
</IAIButton>
<IAIButton leftIcon={<FaCopy />} onClick={handleCopy}>
Copy Error
</IAIButton>
<Link href={url} isExternal>
<IAIButton leftIcon={<FaExternalLinkAlt />}>Create Issue</IAIButton>
</Link>
</Flex>
</Flex>
</Flex>
);
};
export default memo(AppErrorBoundaryFallback);

View File

@ -1,30 +1,21 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import {
ctrlKeyPressed,
metaKeyPressed,
shiftKeyPressed,
} from 'features/ui/store/hotkeysSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import {
setActiveTab,
toggleGalleryPanel,
toggleParametersPanel,
togglePinGalleryPanel,
togglePinParametersPanel,
} from 'features/ui/store/uiSlice';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { isEqual } from 'lodash-es';
import React, { memo } from 'react';
import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
const globalHotkeysSelector = createSelector(
[stateSelector],
({ hotkeys, ui }) => {
({ hotkeys }) => {
const { shift, ctrl, meta } = hotkeys;
const { shouldPinParametersPanel, shouldPinGallery } = ui;
return { shift, ctrl, meta, shouldPinGallery, shouldPinParametersPanel };
return { shift, ctrl, meta };
},
{
memoizeOptions: {
@ -41,9 +32,7 @@ const globalHotkeysSelector = createSelector(
*/
const GlobalHotkeys: React.FC = () => {
const dispatch = useAppDispatch();
const { shift, ctrl, meta, shouldPinParametersPanel, shouldPinGallery } =
useAppSelector(globalHotkeysSelector);
const activeTabName = useAppSelector(activeTabNameSelector);
const { shift, ctrl, meta } = useAppSelector(globalHotkeysSelector);
useHotkeys(
'*',
@ -68,34 +57,6 @@ const GlobalHotkeys: React.FC = () => {
[shift, ctrl, meta]
);
useHotkeys('o', () => {
dispatch(toggleParametersPanel());
if (activeTabName === 'unifiedCanvas' && shouldPinParametersPanel) {
dispatch(requestCanvasRescale());
}
});
useHotkeys(['shift+o'], () => {
dispatch(togglePinParametersPanel());
if (activeTabName === 'unifiedCanvas') {
dispatch(requestCanvasRescale());
}
});
useHotkeys('g', () => {
dispatch(toggleGalleryPanel());
if (activeTabName === 'unifiedCanvas' && shouldPinGallery) {
dispatch(requestCanvasRescale());
}
});
useHotkeys(['shift+g'], () => {
dispatch(togglePinGalleryPanel());
if (activeTabName === 'unifiedCanvas') {
dispatch(requestCanvasRescale());
}
});
useHotkeys('1', () => {
dispatch(setActiveTab('txt2img'));
});
@ -112,6 +73,10 @@ const GlobalHotkeys: React.FC = () => {
dispatch(setActiveTab('nodes'));
});
useHotkeys('5', () => {
dispatch(setActiveTab('modelManager'));
});
return null;
};

View File

@ -3,7 +3,7 @@ import {
createLocalStorageManager,
extendTheme,
} from '@chakra-ui/react';
import { ReactNode, useEffect, useMemo } from 'react';
import { ReactNode, memo, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { theme as invokeAITheme } from 'theme/theme';
@ -46,4 +46,4 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
);
}
export default ThemeLocaleProvider;
export default memo(ThemeLocaleProvider);

View File

@ -3,7 +3,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { toastQueueSelector } from 'features/system/store/systemSelectors';
import { addToast, clearToastQueue } from 'features/system/store/systemSlice';
import { MakeToastArg, makeToast } from 'features/system/util/makeToast';
import { useCallback, useEffect } from 'react';
import { memo, useCallback, useEffect } from 'react';
/**
* Logical component. Watches the toast queue and makes toasts when the queue is not empty.
@ -44,4 +44,4 @@ export const useAppToaster = () => {
return toaster;
};
export default Toaster;
export default memo(Toaster);

View File

@ -9,7 +9,7 @@ export const log = Roarr.child(BASE_CONTEXT);
export const $logger = atom<Logger>(Roarr.child(BASE_CONTEXT));
type LoggerNamespace =
export type LoggerNamespace =
| 'images'
| 'models'
| 'config'

View File

@ -1,12 +1,17 @@
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { createLogWriter } from '@roarr/browser-log-writer';
import { useAppSelector } from 'app/store/storeHooks';
import { systemSelector } from 'features/system/store/systemSelectors';
import { isEqual } from 'lodash-es';
import { useEffect } from 'react';
import { useEffect, useMemo } from 'react';
import { ROARR, Roarr } from 'roarr';
import { $logger, BASE_CONTEXT, LOG_LEVEL_MAP } from './logger';
import {
$logger,
BASE_CONTEXT,
LOG_LEVEL_MAP,
LoggerNamespace,
logger,
} from './logger';
const selector = createSelector(
systemSelector,
@ -25,7 +30,7 @@ const selector = createSelector(
}
);
export const useLogger = () => {
export const useLogger = (namespace: LoggerNamespace) => {
const { consoleLogLevel, shouldLogToConsole } = useAppSelector(selector);
// The provided Roarr browser log writer uses localStorage to config logging to console
@ -57,7 +62,7 @@ export const useLogger = () => {
$logger.set(Roarr.child(newContext));
}, []);
const logger = useStore($logger);
const log = useMemo(() => logger(namespace), [namespace]);
return logger;
return log;
};

View File

@ -1,13 +1,17 @@
import { logger } from 'app/logging/logger';
import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
import {
controlNetImageChanged,
controlNetProcessedImageChanged,
} from 'features/controlNet/store/controlNetSlice';
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { isInvocationNode } from 'features/nodes/types/types';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { clamp } from 'lodash-es';
import { clamp, forEach } from 'lodash-es';
import { api } from 'services/api';
import { imagesApi } from 'services/api/endpoints/images';
import { imagesAdapter } from 'services/api/util';
@ -73,22 +77,61 @@ export const addRequestedSingleImageDeletionListener = () => {
}
// We need to reset the features where the image is in use - none of these work if their image(s) don't exist
if (imageUsage.isCanvasImage) {
dispatch(resetCanvas());
}
if (imageUsage.isControlNetImage) {
dispatch(controlNetReset());
}
imageDTOs.forEach((imageDTO) => {
// reset init image if we deleted it
if (
getState().generation.initialImage?.imageName === imageDTO.image_name
) {
dispatch(clearInitialImage());
}
if (imageUsage.isInitialImage) {
dispatch(clearInitialImage());
}
// reset controlNets that use the deleted images
forEach(getState().controlNet.controlNets, (controlNet) => {
if (
controlNet.controlImage === imageDTO.image_name ||
controlNet.processedControlImage === imageDTO.image_name
) {
dispatch(
controlNetImageChanged({
controlNetId: controlNet.controlNetId,
controlImage: null,
})
);
dispatch(
controlNetProcessedImageChanged({
controlNetId: controlNet.controlNetId,
processedControlImage: null,
})
);
}
});
if (imageUsage.isNodesImage) {
dispatch(nodeEditorReset());
}
// reset nodes that use the deleted images
getState().nodes.nodes.forEach((node) => {
if (!isInvocationNode(node)) {
return;
}
forEach(node.data.inputs, (input) => {
if (
input.type === 'ImageField' &&
input.value?.image_name === imageDTO.image_name
) {
dispatch(
fieldImageValueChanged({
nodeId: node.data.id,
fieldName: input.name,
value: undefined,
})
);
}
});
});
});
// Delete from server
const { requestId } = dispatch(
@ -154,17 +197,58 @@ export const addRequestedMultipleImageDeletionListener = () => {
dispatch(resetCanvas());
}
if (imagesUsage.some((i) => i.isControlNetImage)) {
dispatch(controlNetReset());
}
imageDTOs.forEach((imageDTO) => {
// reset init image if we deleted it
if (
getState().generation.initialImage?.imageName ===
imageDTO.image_name
) {
dispatch(clearInitialImage());
}
if (imagesUsage.some((i) => i.isInitialImage)) {
dispatch(clearInitialImage());
}
// reset controlNets that use the deleted images
forEach(getState().controlNet.controlNets, (controlNet) => {
if (
controlNet.controlImage === imageDTO.image_name ||
controlNet.processedControlImage === imageDTO.image_name
) {
dispatch(
controlNetImageChanged({
controlNetId: controlNet.controlNetId,
controlImage: null,
})
);
dispatch(
controlNetProcessedImageChanged({
controlNetId: controlNet.controlNetId,
processedControlImage: null,
})
);
}
});
if (imagesUsage.some((i) => i.isNodesImage)) {
dispatch(nodeEditorReset());
}
// reset nodes that use the deleted images
getState().nodes.nodes.forEach((node) => {
if (!isInvocationNode(node)) {
return;
}
forEach(node.data.inputs, (input) => {
if (
input.type === 'ImageField' &&
input.value?.image_name === imageDTO.image_name
) {
dispatch(
fieldImageValueChanged({
nodeId: node.data.id,
fieldName: input.name,
value: undefined,
})
);
}
});
});
});
} catch {
// no-op
}

View File

@ -5,6 +5,7 @@ import { modelsApi } from 'services/api/endpoints/models';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { appSocketConnected, socketConnected } from 'services/events/actions';
import { startAppListening } from '../..';
import { size } from 'lodash-es';
export const addSocketConnectedEventListener = () => {
startAppListening({
@ -18,7 +19,7 @@ export const addSocketConnectedEventListener = () => {
const { disabledTabs } = config;
if (!nodes.schema && !disabledTabs.includes('nodes')) {
if (!size(nodes.nodeTemplates) && !disabledTabs.includes('nodes')) {
dispatch(receivedOpenAPISchema());
}

View File

@ -8,8 +8,8 @@ import {
import { memo, ReactNode } from 'react';
export interface IAIButtonProps extends ButtonProps {
tooltip?: string;
tooltipProps?: Omit<TooltipProps, 'children'>;
tooltip?: TooltipProps['label'];
tooltipProps?: Omit<TooltipProps, 'children' | 'label'>;
isChecked?: boolean;
children: ReactNode;
}

View File

@ -34,14 +34,10 @@ const IAICollapse = (props: IAIToggleCollapseProps) => {
gap: 2,
borderTopRadius: 'base',
borderBottomRadius: isOpen ? 0 : 'base',
bg: isOpen
? mode('base.200', 'base.750')(colorMode)
: mode('base.150', 'base.800')(colorMode),
bg: mode('base.250', 'base.750')(colorMode),
color: mode('base.900', 'base.100')(colorMode),
_hover: {
bg: isOpen
? mode('base.250', 'base.700')(colorMode)
: mode('base.200', 'base.750')(colorMode),
bg: mode('base.300', 'base.700')(colorMode),
},
fontSize: 'sm',
fontWeight: 600,
@ -90,9 +86,10 @@ const IAICollapse = (props: IAIToggleCollapseProps) => {
<Collapse in={isOpen} animateOpacity style={{ overflow: 'unset' }}>
<Box
sx={{
p: 4,
p: 2,
pt: 3,
borderBottomRadius: 'base',
bg: 'base.100',
bg: 'base.150',
_dark: {
bg: 'base.800',
},

View File

@ -100,14 +100,18 @@ const IAIDndImage = (props: IAIDndImageProps) => {
const [isHovered, setIsHovered] = useState(false);
const handleMouseOver = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
if (onMouseOver) onMouseOver(e);
if (onMouseOver) {
onMouseOver(e);
}
setIsHovered(true);
},
[onMouseOver]
);
const handleMouseOut = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
if (onMouseOut) onMouseOut(e);
if (onMouseOut) {
onMouseOut(e);
}
setIsHovered(false);
},
[onMouseOut]
@ -122,7 +126,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
? {}
: {
cursor: 'pointer',
bg: mode('base.200', 'base.800')(colorMode),
bg: mode('base.200', 'base.700')(colorMode),
_hover: {
bg: mode('base.300', 'base.650')(colorMode),
color: mode('base.500', 'base.300')(colorMode),

View File

@ -1,4 +1,5 @@
import { Box, Flex, Icon } from '@chakra-ui/react';
import { memo } from 'react';
import { FaExclamation } from 'react-icons/fa';
const IAIErrorLoadingImageFallback = () => {
@ -39,4 +40,4 @@ const IAIErrorLoadingImageFallback = () => {
);
};
export default IAIErrorLoadingImageFallback;
export default memo(IAIErrorLoadingImageFallback);

View File

@ -1,4 +1,5 @@
import { Box, Skeleton } from '@chakra-ui/react';
import { memo } from 'react';
const IAIFillSkeleton = () => {
return (
@ -27,4 +28,4 @@ const IAIFillSkeleton = () => {
);
};
export default IAIFillSkeleton;
export default memo(IAIFillSkeleton);

View File

@ -9,8 +9,8 @@ import { memo } from 'react';
export type IAIIconButtonProps = IconButtonProps & {
role?: string;
tooltip?: string;
tooltipProps?: Omit<TooltipProps, 'children'>;
tooltip?: TooltipProps['label'];
tooltipProps?: Omit<TooltipProps, 'children' | 'label'>;
isChecked?: boolean;
};

View File

@ -1,4 +1,5 @@
import { Badge, Flex } from '@chakra-ui/react';
import { memo } from 'react';
import { ImageDTO } from 'services/api/types';
type ImageMetadataOverlayProps = {
@ -26,4 +27,4 @@ const ImageMetadataOverlay = ({ imageDTO }: ImageMetadataOverlayProps) => {
);
};
export default ImageMetadataOverlay;
export default memo(ImageMetadataOverlay);

View File

@ -1,4 +1,5 @@
import { Box, Flex, Heading } from '@chakra-ui/react';
import { memo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
type ImageUploadOverlayProps = {
@ -87,4 +88,4 @@ const ImageUploadOverlay = (props: ImageUploadOverlayProps) => {
</Box>
);
};
export default ImageUploadOverlay;
export default memo(ImageUploadOverlay);

View File

@ -150,7 +150,9 @@ const ImageUploader = (props: ImageUploaderProps) => {
{...getRootProps({ style: {} })}
onKeyDown={(e: KeyboardEvent) => {
// Bail out if user hits spacebar - do not open the uploader
if (e.key === ' ') return;
if (e.key === ' ') {
return;
}
}}
>
<input {...getInputProps()} />

View File

@ -1,4 +1,5 @@
import { Flex, Icon } from '@chakra-ui/react';
import { memo } from 'react';
import { FaImage } from 'react-icons/fa';
const SelectImagePlaceholder = () => {
@ -19,4 +20,4 @@ const SelectImagePlaceholder = () => {
);
};
export default SelectImagePlaceholder;
export default memo(SelectImagePlaceholder);

View File

@ -1,4 +1,5 @@
import { Box } from '@chakra-ui/react';
import { memo } from 'react';
type Props = {
isSelected: boolean;
@ -18,6 +19,7 @@ const SelectionOverlay = ({ isSelected, isHovered }: Props) => {
opacity: isSelected ? 1 : 0.7,
transitionProperty: 'common',
transitionDuration: '0.1s',
pointerEvents: 'none',
shadow: isSelected
? isHovered
? 'hoverSelected.light'
@ -39,4 +41,4 @@ const SelectionOverlay = ({ isSelected, isHovered }: Props) => {
);
};
export default SelectionOverlay;
export default memo(SelectionOverlay);

View File

@ -2,71 +2,108 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
// import { validateSeedWeights } from 'common/util/seedWeightPairs';
import { isInvocationNode } from 'features/nodes/types/types';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { forEach } from 'lodash-es';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import { modelsApi } from '../../services/api/endpoints/models';
import { forEach, map } from 'lodash-es';
import { getConnectedEdges } from 'reactflow';
const readinessSelector = createSelector(
const selector = createSelector(
[stateSelector, activeTabNameSelector],
(state, activeTabName) => {
const { generation, system } = state;
const { initialImage } = generation;
const { generation, system, nodes } = state;
const { initialImage, model } = generation;
const { isProcessing, isConnected } = system;
let isReady = true;
const reasonsWhyNotReady: string[] = [];
const reasons: string[] = [];
if (activeTabName === 'img2img' && !initialImage) {
isReady = false;
reasonsWhyNotReady.push('No initial image selected');
}
const { isSuccess: mainModelsSuccessfullyLoaded } =
modelsApi.endpoints.getMainModels.select(NON_REFINER_BASE_MODELS)(state);
if (!mainModelsSuccessfullyLoaded) {
isReady = false;
reasonsWhyNotReady.push('Models are not loaded');
}
// TODO: job queue
// Cannot generate if already processing an image
if (isProcessing) {
isReady = false;
reasonsWhyNotReady.push('System Busy');
reasons.push('System busy');
}
// Cannot generate if not connected
if (!isConnected) {
isReady = false;
reasonsWhyNotReady.push('System Disconnected');
reasons.push('System disconnected');
}
// // Cannot generate variations without valid seed weights
// if (
// shouldGenerateVariations &&
// (!(validateSeedWeights(seedWeights) || seedWeights === '') || seed === -1)
// ) {
// isReady = false;
// reasonsWhyNotReady.push('Seed-Weights badly formatted.');
// }
if (activeTabName === 'img2img' && !initialImage) {
reasons.push('No initial image selected');
}
forEach(state.controlNet.controlNets, (controlNet, id) => {
if (!controlNet.model) {
isReady = false;
reasonsWhyNotReady.push(`ControlNet ${id} has no model selected.`);
if (activeTabName === 'nodes' && nodes.shouldValidateGraph) {
if (!nodes.nodes.length) {
reasons.push('No nodes in graph');
}
});
// All good
return { isReady, reasonsWhyNotReady };
nodes.nodes.forEach((node) => {
if (!isInvocationNode(node)) {
return;
}
const nodeTemplate = nodes.nodeTemplates[node.data.type];
if (!nodeTemplate) {
// Node type not found
reasons.push('Missing node template');
return;
}
const connectedEdges = getConnectedEdges([node], nodes.edges);
forEach(node.data.inputs, (field) => {
const fieldTemplate = nodeTemplate.inputs[field.name];
const hasConnection = connectedEdges.some(
(edge) =>
edge.target === node.id && edge.targetHandle === field.name
);
if (!fieldTemplate) {
reasons.push('Missing field template');
return;
}
if (fieldTemplate.required && !field.value && !hasConnection) {
reasons.push(
`${node.data.label || nodeTemplate.title} -> ${
field.label || fieldTemplate.title
} missing input`
);
return;
}
});
});
} else {
if (!model) {
reasons.push('No model selected');
}
if (state.controlNet.isEnabled) {
map(state.controlNet.controlNets).forEach((controlNet, i) => {
if (!controlNet.isEnabled) {
return;
}
if (!controlNet.model) {
reasons.push(`ControlNet ${i + 1} has no model selected.`);
}
if (
!controlNet.controlImage ||
(!controlNet.processedControlImage &&
controlNet.processorType !== 'none')
) {
reasons.push(`ControlNet ${i + 1} has no control image`);
}
});
}
}
return { isReady: !reasons.length, isProcessing, reasons };
},
defaultSelectorOptions
);
export const useIsReadyToInvoke = () => {
const { isReady } = useAppSelector(readinessSelector);
return isReady;
const { isReady, isProcessing, reasons } = useAppSelector(selector);
return { isReady, isProcessing, reasons };
};

View File

@ -11,8 +11,14 @@ export default function useResolution():
const tabletResolutions = ['md', 'lg'];
const desktopResolutions = ['xl', '2xl'];
if (mobileResolutions.includes(breakpointValue)) return 'mobile';
if (tabletResolutions.includes(breakpointValue)) return 'tablet';
if (desktopResolutions.includes(breakpointValue)) return 'desktop';
if (mobileResolutions.includes(breakpointValue)) {
return 'mobile';
}
if (tabletResolutions.includes(breakpointValue)) {
return 'tablet';
}
if (desktopResolutions.includes(breakpointValue)) {
return 'desktop';
}
return 'unknown';
}

View File

@ -0,0 +1,2 @@
export const colorTokenToCssVar = (colorToken: string) =>
`var(--invokeai-colors-${colorToken.split('.').join('-')}`;

View File

@ -6,7 +6,11 @@ export const dateComparator = (a: string, b: string) => {
const dateB = new Date(b);
// sort in ascending order
if (dateA > dateB) return 1;
if (dateA < dateB) return -1;
if (dateA > dateB) {
return 1;
}
if (dateA < dateB) {
return -1;
}
return 0;
};

View File

@ -5,7 +5,9 @@ type Base64AndCaption = {
const openBase64ImageInTab = (images: Base64AndCaption[]) => {
const w = window.open('');
if (!w) return;
if (!w) {
return;
}
images.forEach((i) => {
const image = new Image();

View File

@ -5,6 +5,7 @@ import { clearCanvasHistory } from 'features/canvas/store/canvasSlice';
import { useTranslation } from 'react-i18next';
import { FaTrash } from 'react-icons/fa';
import { isStagingSelector } from '../store/canvasSelectors';
import { memo } from 'react';
const ClearCanvasHistoryButtonModal = () => {
const isStaging = useAppSelector(isStagingSelector);
@ -28,4 +29,4 @@ const ClearCanvasHistoryButtonModal = () => {
</IAIAlertDialog>
);
};
export default ClearCanvasHistoryButtonModal;
export default memo(ClearCanvasHistoryButtonModal);

View File

@ -1,6 +1,6 @@
import { Box, chakra, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import {
canvasSelector,
@ -9,7 +9,7 @@ import {
import Konva from 'konva';
import { KonvaEventObject } from 'konva/lib/Node';
import { Vector2d } from 'konva/lib/types';
import { useCallback, useRef } from 'react';
import { memo, useCallback, useEffect, useRef } from 'react';
import { Layer, Stage } from 'react-konva';
import useCanvasDragMove from '../hooks/useCanvasDragMove';
import useCanvasHotkeys from '../hooks/useCanvasHotkeys';
@ -18,6 +18,7 @@ import useCanvasMouseMove from '../hooks/useCanvasMouseMove';
import useCanvasMouseOut from '../hooks/useCanvasMouseOut';
import useCanvasMouseUp from '../hooks/useCanvasMouseUp';
import useCanvasWheel from '../hooks/useCanvasZoom';
import { canvasResized } from '../store/canvasSlice';
import {
setCanvasBaseLayer,
setCanvasStage,
@ -106,7 +107,8 @@ const IAICanvas = () => {
shouldAntialias,
} = useAppSelector(selector);
useCanvasHotkeys();
const dispatch = useAppDispatch();
const containerRef = useRef<HTMLDivElement>(null);
const stageRef = useRef<Konva.Stage | null>(null);
const canvasBaseLayerRef = useRef<Konva.Layer | null>(null);
@ -137,8 +139,30 @@ const IAICanvas = () => {
const { handleDragStart, handleDragMove, handleDragEnd } =
useCanvasDragMove();
useEffect(() => {
if (!containerRef.current) {
return;
}
const resizeObserver = new ResizeObserver((entries) => {
for (const entry of entries) {
if (entry.contentBoxSize) {
const { width, height } = entry.contentRect;
dispatch(canvasResized({ width, height }));
}
}
});
resizeObserver.observe(containerRef.current);
return () => {
resizeObserver.disconnect();
};
}, [dispatch]);
return (
<Flex
id="canvas-container"
ref={containerRef}
sx={{
position: 'relative',
height: '100%',
@ -146,13 +170,18 @@ const IAICanvas = () => {
borderRadius: 'base',
}}
>
<Box sx={{ position: 'relative' }}>
<Box
sx={{
position: 'absolute',
// top: 0,
// insetInlineStart: 0,
}}
>
<ChakraStage
tabIndex={-1}
ref={canvasStageRefCallback}
sx={{
outline: 'none',
// boxShadow: '0px 0px 0px 1px var(--border-color-light)',
overflow: 'hidden',
cursor: stageCursor ? stageCursor : undefined,
canvas: {
@ -213,11 +242,11 @@ const IAICanvas = () => {
/>
</Layer>
</ChakraStage>
<IAICanvasStatusText />
<IAICanvasStagingAreaToolbar />
</Box>
<IAICanvasStatusText />
<IAICanvasStagingAreaToolbar />
</Flex>
);
};
export default IAICanvas;
export default memo(IAICanvas);

View File

@ -4,6 +4,7 @@ import { isEqual } from 'lodash-es';
import { Group, Rect } from 'react-konva';
import { canvasSelector } from '../store/canvasSelectors';
import { memo } from 'react';
const selector = createSelector(
canvasSelector,
@ -67,4 +68,4 @@ const IAICanvasBoundingBoxOverlay = () => {
);
};
export default IAICanvasBoundingBoxOverlay;
export default memo(IAICanvasBoundingBoxOverlay);

View File

@ -6,7 +6,7 @@ import { useAppSelector } from 'app/store/storeHooks';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { isEqual, range } from 'lodash-es';
import { ReactNode, useCallback, useLayoutEffect, useState } from 'react';
import { ReactNode, memo, useCallback, useLayoutEffect, useState } from 'react';
import { Group, Line as KonvaLine } from 'react-konva';
const selector = createSelector(
@ -117,4 +117,4 @@ const IAICanvasGrid = () => {
return <Group>{gridLines}</Group>;
};
export default IAICanvasGrid;
export default memo(IAICanvasGrid);

View File

@ -4,6 +4,7 @@ import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import useImage from 'use-image';
import { CanvasImage } from '../store/canvasTypes';
import { $authToken } from 'services/api/client';
import { memo } from 'react';
type IAICanvasImageProps = {
canvasImage: CanvasImage;
@ -25,4 +26,4 @@ const IAICanvasImage = (props: IAICanvasImageProps) => {
return <Image x={x} y={y} image={image} listening={false} />;
};
export default IAICanvasImage;
export default memo(IAICanvasImage);

View File

@ -4,7 +4,7 @@ import { systemSelector } from 'features/system/store/systemSelectors';
import { ImageConfig } from 'konva/lib/shapes/Image';
import { isEqual } from 'lodash-es';
import { useEffect, useState } from 'react';
import { memo, useEffect, useState } from 'react';
import { Image as KonvaImage } from 'react-konva';
import { canvasSelector } from '../store/canvasSelectors';
@ -66,4 +66,4 @@ const IAICanvasIntermediateImage = (props: Props) => {
) : null;
};
export default IAICanvasIntermediateImage;
export default memo(IAICanvasIntermediateImage);

View File

@ -7,7 +7,7 @@ import { Rect } from 'react-konva';
import { rgbaColorToString } from 'features/canvas/util/colorToString';
import Konva from 'konva';
import { isNumber } from 'lodash-es';
import { useCallback, useEffect, useRef, useState } from 'react';
import { memo, useCallback, useEffect, useRef, useState } from 'react';
export const canvasMaskCompositerSelector = createSelector(
canvasSelector,
@ -125,7 +125,9 @@ const IAICanvasMaskCompositer = (props: IAICanvasMaskCompositerProps) => {
}, [offset]);
useEffect(() => {
if (fillPatternImage) return;
if (fillPatternImage) {
return;
}
const image = new Image();
image.onload = () => {
@ -135,7 +137,9 @@ const IAICanvasMaskCompositer = (props: IAICanvasMaskCompositerProps) => {
}, [fillPatternImage, maskColorString]);
useEffect(() => {
if (!fillPatternImage) return;
if (!fillPatternImage) {
return;
}
fillPatternImage.src = getColoredSVG(maskColorString);
}, [fillPatternImage, maskColorString]);
@ -151,8 +155,9 @@ const IAICanvasMaskCompositer = (props: IAICanvasMaskCompositerProps) => {
!isNumber(stageScale) ||
!isNumber(stageDimensions.width) ||
!isNumber(stageDimensions.height)
)
) {
return null;
}
return (
<Rect
@ -172,4 +177,4 @@ const IAICanvasMaskCompositer = (props: IAICanvasMaskCompositerProps) => {
);
};
export default IAICanvasMaskCompositer;
export default memo(IAICanvasMaskCompositer);

View File

@ -6,6 +6,7 @@ import { isEqual } from 'lodash-es';
import { Group, Line } from 'react-konva';
import { isCanvasMaskLine } from '../store/canvasTypes';
import { memo } from 'react';
export const canvasLinesSelector = createSelector(
[canvasSelector],
@ -52,4 +53,4 @@ const IAICanvasLines = (props: InpaintingCanvasLinesProps) => {
);
};
export default IAICanvasLines;
export default memo(IAICanvasLines);

View File

@ -12,6 +12,7 @@ import {
isCanvasFillRect,
} from '../store/canvasTypes';
import IAICanvasImage from './IAICanvasImage';
import { memo } from 'react';
const selector = createSelector(
[canvasSelector],
@ -33,7 +34,9 @@ const selector = createSelector(
const IAICanvasObjectRenderer = () => {
const { objects } = useAppSelector(selector);
if (!objects) return null;
if (!objects) {
return null;
}
return (
<Group name="outpainting-objects" listening={false}>
@ -101,4 +104,4 @@ const IAICanvasObjectRenderer = () => {
);
};
export default IAICanvasObjectRenderer;
export default memo(IAICanvasObjectRenderer);

View File

@ -1,89 +0,0 @@
import { Flex, Spinner } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
canvasSelector,
initialCanvasImageSelector,
} from 'features/canvas/store/canvasSelectors';
import {
resizeAndScaleCanvas,
resizeCanvas,
setCanvasContainerDimensions,
setDoesCanvasNeedScaling,
} from 'features/canvas/store/canvasSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useLayoutEffect, useRef } from 'react';
const canvasResizerSelector = createSelector(
canvasSelector,
initialCanvasImageSelector,
activeTabNameSelector,
(canvas, initialCanvasImage, activeTabName) => {
const { doesCanvasNeedScaling, isCanvasInitialized } = canvas;
return {
doesCanvasNeedScaling,
activeTabName,
initialCanvasImage,
isCanvasInitialized,
};
}
);
const IAICanvasResizer = () => {
const dispatch = useAppDispatch();
const {
doesCanvasNeedScaling,
activeTabName,
initialCanvasImage,
isCanvasInitialized,
} = useAppSelector(canvasResizerSelector);
const ref = useRef<HTMLDivElement>(null);
useLayoutEffect(() => {
window.setTimeout(() => {
if (!ref.current) return;
const { clientWidth, clientHeight } = ref.current;
dispatch(
setCanvasContainerDimensions({
width: clientWidth,
height: clientHeight,
})
);
if (!isCanvasInitialized) {
dispatch(resizeAndScaleCanvas());
} else {
dispatch(resizeCanvas());
}
dispatch(setDoesCanvasNeedScaling(false));
}, 0);
}, [
dispatch,
initialCanvasImage,
doesCanvasNeedScaling,
activeTabName,
isCanvasInitialized,
]);
return (
<Flex
ref={ref}
sx={{
flexDirection: 'column',
alignItems: 'center',
justifyContent: 'center',
gap: 4,
width: '100%',
height: '100%',
}}
>
<Spinner thickness="2px" size="xl" />
</Flex>
);
};
export default IAICanvasResizer;

View File

@ -6,6 +6,7 @@ import { isEqual } from 'lodash-es';
import { Group, Rect } from 'react-konva';
import IAICanvasImage from './IAICanvasImage';
import { memo } from 'react';
const selector = createSelector(
[canvasSelector],
@ -88,4 +89,4 @@ const IAICanvasStagingArea = (props: Props) => {
);
};
export default IAICanvasStagingArea;
export default memo(IAICanvasStagingArea);

View File

@ -13,7 +13,7 @@ import {
} from 'features/canvas/store/canvasSlice';
import { isEqual } from 'lodash-es';
import { useCallback } from 'react';
import { memo, useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import {
@ -129,7 +129,9 @@ const IAICanvasStagingAreaToolbar = () => {
currentStagingAreaImage?.imageName ?? skipToken
);
if (!currentStagingAreaImage) return null;
if (!currentStagingAreaImage) {
return null;
}
return (
<Flex
@ -138,11 +140,10 @@ const IAICanvasStagingAreaToolbar = () => {
w="100%"
align="center"
justify="center"
filter="drop-shadow(0 0.5rem 1rem rgba(0,0,0))"
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
>
<ButtonGroup isAttached>
<ButtonGroup isAttached borderRadius="base" shadow="dark-lg">
<IAIIconButton
tooltip={`${t('unifiedCanvas.previous')} (Left)`}
aria-label={`${t('unifiedCanvas.previous')} (Left)`}
@ -207,4 +208,4 @@ const IAICanvasStagingAreaToolbar = () => {
);
};
export default IAICanvasStagingAreaToolbar;
export default memo(IAICanvasStagingAreaToolbar);

View File

@ -7,6 +7,7 @@ import { isEqual } from 'lodash-es';
import { useTranslation } from 'react-i18next';
import roundToHundreth from '../util/roundToHundreth';
import IAICanvasStatusTextCursorPos from './IAICanvasStatusText/IAICanvasStatusTextCursorPos';
import { memo } from 'react';
const warningColor = 'var(--invokeai-colors-warning-500)';
@ -162,4 +163,4 @@ const IAICanvasStatusText = () => {
);
};
export default IAICanvasStatusText;
export default memo(IAICanvasStatusText);

View File

@ -10,6 +10,7 @@ import {
COLOR_PICKER_SIZE,
COLOR_PICKER_STROKE_RADIUS,
} from '../util/constants';
import { memo } from 'react';
const canvasBrushPreviewSelector = createSelector(
canvasSelector,
@ -134,7 +135,9 @@ const IAICanvasToolPreview = (props: GroupConfig) => {
clip,
} = useAppSelector(canvasBrushPreviewSelector);
if (!shouldDrawBrushPreview) return null;
if (!shouldDrawBrushPreview) {
return null;
}
return (
<Group listening={false} {...clip} {...rest}>
@ -206,4 +209,4 @@ const IAICanvasToolPreview = (props: GroupConfig) => {
);
};
export default IAICanvasToolPreview;
export default memo(IAICanvasToolPreview);

View File

@ -19,7 +19,7 @@ import { KonvaEventObject } from 'konva/lib/Node';
import { Vector2d } from 'konva/lib/types';
import { isEqual } from 'lodash-es';
import { useCallback, useEffect, useRef, useState } from 'react';
import { memo, useCallback, useEffect, useRef, useState } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { Group, Rect, Transformer } from 'react-konva';
@ -85,7 +85,9 @@ const IAICanvasBoundingBox = (props: IAICanvasBoundingBoxPreviewProps) => {
useState(false);
useEffect(() => {
if (!transformerRef.current || !shapeRef.current) return;
if (!transformerRef.current || !shapeRef.current) {
return;
}
transformerRef.current.nodes([shapeRef.current]);
transformerRef.current.getLayer()?.batchDraw();
}, []);
@ -133,7 +135,9 @@ const IAICanvasBoundingBox = (props: IAICanvasBoundingBoxPreviewProps) => {
* not its width and height. We need to un-scale the width and height before
* setting the values.
*/
if (!shapeRef.current) return;
if (!shapeRef.current) {
return;
}
const rect = shapeRef.current;
@ -313,4 +317,4 @@ const IAICanvasBoundingBox = (props: IAICanvasBoundingBoxPreviewProps) => {
);
};
export default IAICanvasBoundingBox;
export default memo(IAICanvasBoundingBox);

View File

@ -20,6 +20,7 @@ import {
} from 'features/canvas/store/canvasSlice';
import { rgbaColorToString } from 'features/canvas/util/colorToString';
import { isEqual } from 'lodash-es';
import { memo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
@ -150,4 +151,4 @@ const IAICanvasMaskOptions = () => {
);
};
export default IAICanvasMaskOptions;
export default memo(IAICanvasMaskOptions);

View File

@ -18,7 +18,7 @@ import {
} from 'features/canvas/store/canvasSlice';
import { isEqual } from 'lodash-es';
import { ChangeEvent } from 'react';
import { ChangeEvent, memo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaWrench } from 'react-icons/fa';
@ -163,4 +163,4 @@ const IAICanvasSettingsButtonPopover = () => {
);
};
export default IAICanvasSettingsButtonPopover;
export default memo(IAICanvasSettingsButtonPopover);

View File

@ -18,6 +18,7 @@ import {
} from 'features/canvas/store/canvasSlice';
import { systemSelector } from 'features/system/store/systemSelectors';
import { clamp, isEqual } from 'lodash-es';
import { memo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
@ -252,4 +253,4 @@ const IAICanvasToolChooserOptions = () => {
);
};
export default IAICanvasToolChooserOptions;
export default memo(IAICanvasToolChooserOptions);

View File

@ -18,7 +18,6 @@ import {
import {
resetCanvas,
resetCanvasView,
resizeAndScaleCanvas,
setIsMaskEnabled,
setLayer,
setTool,
@ -48,6 +47,7 @@ import IAICanvasRedoButton from './IAICanvasRedoButton';
import IAICanvasSettingsButtonPopover from './IAICanvasSettingsButtonPopover';
import IAICanvasToolChooserOptions from './IAICanvasToolChooserOptions';
import IAICanvasUndoButton from './IAICanvasUndoButton';
import { memo } from 'react';
export const selector = createSelector(
[systemSelector, canvasSelector, isStagingSelector],
@ -166,7 +166,9 @@ const IAICanvasToolbar = () => {
const handleResetCanvasView = (shouldScaleTo1 = false) => {
const canvasBaseLayer = getCanvasBaseLayer();
if (!canvasBaseLayer) return;
if (!canvasBaseLayer) {
return;
}
const clientRect = canvasBaseLayer.getClientRect({
skipTransform: true,
});
@ -180,7 +182,6 @@ const IAICanvasToolbar = () => {
const handleResetCanvas = () => {
dispatch(resetCanvas());
dispatch(resizeAndScaleCanvas());
};
const handleMergeVisible = () => {
@ -309,4 +310,4 @@ const IAICanvasToolbar = () => {
);
};
export default IAICanvasToolbar;
export default memo(IAICanvasToolbar);

View File

@ -32,13 +32,17 @@ const useCanvasDrag = () => {
return {
handleDragStart: useCallback(() => {
if (!((tool === 'move' || isStaging) && !isMovingBoundingBox)) return;
if (!((tool === 'move' || isStaging) && !isMovingBoundingBox)) {
return;
}
dispatch(setIsMovingStage(true));
}, [dispatch, isMovingBoundingBox, isStaging, tool]),
handleDragMove: useCallback(
(e: KonvaEventObject<MouseEvent>) => {
if (!((tool === 'move' || isStaging) && !isMovingBoundingBox)) return;
if (!((tool === 'move' || isStaging) && !isMovingBoundingBox)) {
return;
}
const newCoordinates = { x: e.target.x(), y: e.target.y() };
@ -48,7 +52,9 @@ const useCanvasDrag = () => {
),
handleDragEnd: useCallback(() => {
if (!((tool === 'move' || isStaging) && !isMovingBoundingBox)) return;
if (!((tool === 'move' || isStaging) && !isMovingBoundingBox)) {
return;
}
dispatch(setIsMovingStage(false));
}, [dispatch, isMovingBoundingBox, isStaging, tool]),
};

View File

@ -134,7 +134,9 @@ const useInpaintingCanvasHotkeys = () => {
useHotkeys(
['space'],
(e: KeyboardEvent) => {
if (e.repeat) return;
if (e.repeat) {
return;
}
canvasStage?.container().focus();

View File

@ -38,7 +38,9 @@ const useCanvasMouseDown = (stageRef: MutableRefObject<Konva.Stage | null>) => {
return useCallback(
(e: KonvaEventObject<MouseEvent | TouchEvent>) => {
if (!stageRef.current) return;
if (!stageRef.current) {
return;
}
stageRef.current.container().focus();
@ -54,7 +56,9 @@ const useCanvasMouseDown = (stageRef: MutableRefObject<Konva.Stage | null>) => {
const scaledCursorPosition = getScaledCursorPosition(stageRef.current);
if (!scaledCursorPosition) return;
if (!scaledCursorPosition) {
return;
}
e.evt.preventDefault();

View File

@ -41,11 +41,15 @@ const useCanvasMouseMove = (
const { updateColorUnderCursor } = useColorPicker();
return useCallback(() => {
if (!stageRef.current) return;
if (!stageRef.current) {
return;
}
const scaledCursorPosition = getScaledCursorPosition(stageRef.current);
if (!scaledCursorPosition) return;
if (!scaledCursorPosition) {
return;
}
dispatch(setCursorPosition(scaledCursorPosition));
@ -56,7 +60,9 @@ const useCanvasMouseMove = (
return;
}
if (!isDrawing || tool === 'move' || isStaging) return;
if (!isDrawing || tool === 'move' || isStaging) {
return;
}
didMouseMoveRef.current = true;
dispatch(

View File

@ -47,7 +47,9 @@ const useCanvasMouseUp = (
if (!didMouseMoveRef.current && isDrawing && stageRef.current) {
const scaledCursorPosition = getScaledCursorPosition(stageRef.current);
if (!scaledCursorPosition) return;
if (!scaledCursorPosition) {
return;
}
/**
* Extend the current line.

View File

@ -35,13 +35,17 @@ const useCanvasWheel = (stageRef: MutableRefObject<Konva.Stage | null>) => {
return useCallback(
(e: KonvaEventObject<WheelEvent>) => {
// stop default scrolling
if (!stageRef.current || isMoveStageKeyHeld) return;
if (!stageRef.current || isMoveStageKeyHeld) {
return;
}
e.evt.preventDefault();
const cursorPos = stageRef.current.getPointerPosition();
if (!cursorPos) return;
if (!cursorPos) {
return;
}
const mousePointTo = {
x: (cursorPos.x - stageRef.current.x()) / stageScale,

View File

@ -16,11 +16,15 @@ const useColorPicker = () => {
return {
updateColorUnderCursor: () => {
if (!stage || !canvasBaseLayer) return;
if (!stage || !canvasBaseLayer) {
return;
}
const position = stage.getPointerPosition();
if (!position) return;
if (!position) {
return;
}
const pixelRatio = Konva.pixelRatio;

View File

@ -3,8 +3,4 @@ import { CanvasState } from './canvasTypes';
/**
* Canvas slice persist denylist
*/
export const canvasPersistDenylist: (keyof CanvasState)[] = [
'cursorPosition',
'isCanvasInitialized',
'doesCanvasNeedScaling',
];
export const canvasPersistDenylist: (keyof CanvasState)[] = ['cursorPosition'];

View File

@ -5,10 +5,6 @@ import {
roundToMultiple,
} from 'common/util/roundDownToMultiple';
import { setAspectRatio } from 'features/parameters/store/generationSlice';
import {
setActiveTab,
setShouldUseCanvasBetaLayout,
} from 'features/ui/store/uiSlice';
import { IRect, Vector2d } from 'konva/lib/types';
import { clamp, cloneDeep } from 'lodash-es';
import { RgbaColor } from 'react-colorful';
@ -50,12 +46,9 @@ export const initialCanvasState: CanvasState = {
boundingBoxScaleMethod: 'none',
brushColor: { r: 90, g: 90, b: 255, a: 1 },
brushSize: 50,
canvasContainerDimensions: { width: 0, height: 0 },
colorPickerColor: { r: 90, g: 90, b: 255, a: 1 },
cursorPosition: null,
doesCanvasNeedScaling: false,
futureLayerStates: [],
isCanvasInitialized: false,
isDrawing: false,
isMaskEnabled: true,
isMouseOverBoundingBox: false,
@ -208,7 +201,6 @@ export const canvasSlice = createSlice({
};
state.futureLayerStates = [];
state.isCanvasInitialized = false;
const newScale = calculateScale(
stageDimensions.width,
stageDimensions.height,
@ -228,7 +220,6 @@ export const canvasSlice = createSlice({
);
state.stageScale = newScale;
state.stageCoordinates = newCoordinates;
state.doesCanvasNeedScaling = true;
},
setBoundingBoxDimensions: (state, action: PayloadAction<Dimensions>) => {
const newDimensions = roundDimensionsTo64(action.payload);
@ -258,9 +249,6 @@ export const canvasSlice = createSlice({
setBoundingBoxPreviewFill: (state, action: PayloadAction<RgbaColor>) => {
state.boundingBoxPreviewFill = action.payload;
},
setDoesCanvasNeedScaling: (state, action: PayloadAction<boolean>) => {
state.doesCanvasNeedScaling = action.payload;
},
setStageScale: (state, action: PayloadAction<number>) => {
state.stageScale = action.payload;
},
@ -397,7 +385,9 @@ export const canvasSlice = createSlice({
const { tool, layer, brushColor, brushSize, shouldRestrictStrokesToBox } =
state;
if (tool === 'move' || tool === 'colorPicker') return;
if (tool === 'move' || tool === 'colorPicker') {
return;
}
const newStrokeWidth = brushSize / 2;
@ -434,14 +424,18 @@ export const canvasSlice = createSlice({
addPointToCurrentLine: (state, action: PayloadAction<number[]>) => {
const lastLine = state.layerState.objects.findLast(isCanvasAnyLine);
if (!lastLine) return;
if (!lastLine) {
return;
}
lastLine.points.push(...action.payload);
},
undo: (state) => {
const targetState = state.pastLayerStates.pop();
if (!targetState) return;
if (!targetState) {
return;
}
state.futureLayerStates.unshift(cloneDeep(state.layerState));
@ -454,7 +448,9 @@ export const canvasSlice = createSlice({
redo: (state) => {
const targetState = state.futureLayerStates.shift();
if (!targetState) return;
if (!targetState) {
return;
}
state.pastLayerStates.push(cloneDeep(state.layerState));
@ -485,97 +481,14 @@ export const canvasSlice = createSlice({
state.layerState = initialLayerState;
state.futureLayerStates = [];
},
setCanvasContainerDimensions: (
canvasResized: (
state,
action: PayloadAction<Dimensions>
action: PayloadAction<{ width: number; height: number }>
) => {
state.canvasContainerDimensions = action.payload;
},
resizeAndScaleCanvas: (state) => {
const { width: containerWidth, height: containerHeight } =
state.canvasContainerDimensions;
const initialCanvasImage =
state.layerState.objects.find(isCanvasBaseImage);
const { width, height } = action.payload;
const newStageDimensions = {
width: Math.floor(containerWidth),
height: Math.floor(containerHeight),
};
if (!initialCanvasImage) {
const newScale = calculateScale(
newStageDimensions.width,
newStageDimensions.height,
512,
512,
STAGE_PADDING_PERCENTAGE
);
const newCoordinates = calculateCoordinates(
newStageDimensions.width,
newStageDimensions.height,
0,
0,
512,
512,
newScale
);
const newBoundingBoxDimensions = { width: 512, height: 512 };
state.stageScale = newScale;
state.stageCoordinates = newCoordinates;
state.stageDimensions = newStageDimensions;
state.boundingBoxCoordinates = { x: 0, y: 0 };
state.boundingBoxDimensions = newBoundingBoxDimensions;
if (state.boundingBoxScaleMethod === 'auto') {
const scaledDimensions = getScaledBoundingBoxDimensions(
newBoundingBoxDimensions
);
state.scaledBoundingBoxDimensions = scaledDimensions;
}
return;
}
const { width: imageWidth, height: imageHeight } = initialCanvasImage;
const padding = 0.95;
const newScale = calculateScale(
containerWidth,
containerHeight,
imageWidth,
imageHeight,
padding
);
const newCoordinates = calculateCoordinates(
newStageDimensions.width,
newStageDimensions.height,
0,
0,
imageWidth,
imageHeight,
newScale
);
state.minimumStageScale = newScale;
state.stageScale = newScale;
state.stageCoordinates = floorCoordinates(newCoordinates);
state.stageDimensions = newStageDimensions;
state.isCanvasInitialized = true;
},
resizeCanvas: (state) => {
const { width: containerWidth, height: containerHeight } =
state.canvasContainerDimensions;
const newStageDimensions = {
width: Math.floor(containerWidth),
height: Math.floor(containerHeight),
width: Math.floor(width),
height: Math.floor(height),
};
state.stageDimensions = newStageDimensions;
@ -868,14 +781,6 @@ export const canvasSlice = createSlice({
state.layerState.stagingArea = initialLayerState.stagingArea;
}
});
builder.addCase(setShouldUseCanvasBetaLayout, (state) => {
state.doesCanvasNeedScaling = true;
});
builder.addCase(setActiveTab, (state) => {
state.doesCanvasNeedScaling = true;
});
builder.addCase(setAspectRatio, (state, action) => {
const ratio = action.payload;
if (ratio) {
@ -907,8 +812,6 @@ export const {
resetCanvas,
resetCanvasInteractionState,
resetCanvasView,
resizeAndScaleCanvas,
resizeCanvas,
setBoundingBoxCoordinates,
setBoundingBoxDimensions,
setBoundingBoxPreviewFill,
@ -916,10 +819,8 @@ export const {
flipBoundingBoxAxes,
setBrushColor,
setBrushSize,
setCanvasContainerDimensions,
setColorPickerColor,
setCursorPosition,
setDoesCanvasNeedScaling,
setInitialCanvasImage,
setIsDrawing,
setIsMaskEnabled,
@ -958,6 +859,7 @@ export const {
stagingAreaInitialized,
canvasSessionIdChanged,
setShouldAntialias,
canvasResized,
} = canvasSlice.actions;
export default canvasSlice.reducer;

View File

@ -126,12 +126,9 @@ export interface CanvasState {
boundingBoxScaleMethod: BoundingBoxScale;
brushColor: RgbaColor;
brushSize: number;
canvasContainerDimensions: Dimensions;
colorPickerColor: RgbaColor;
cursorPosition: Vector2d | null;
doesCanvasNeedScaling: boolean;
futureLayerStates: CanvasLayerState[];
isCanvasInitialized: boolean;
isDrawing: boolean;
isMaskEnabled: boolean;
isMouseOverBoundingBox: boolean;

View File

@ -1,16 +0,0 @@
import { AppDispatch, AppGetState } from 'app/store/store';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { debounce } from 'lodash-es';
import { setDoesCanvasNeedScaling } from '../canvasSlice';
const debouncedCanvasScale = debounce((dispatch: AppDispatch) => {
dispatch(setDoesCanvasNeedScaling(true));
}, 300);
export const requestCanvasRescale =
() => (dispatch: AppDispatch, getState: AppGetState) => {
const activeTabName = activeTabNameSelector(getState());
if (activeTabName === 'unifiedCanvas') {
debouncedCanvasScale(dispatch);
}
};

View File

@ -5,7 +5,9 @@ const getScaledCursorPosition = (stage: Stage) => {
const stageTransform = stage.getAbsoluteTransform().copy();
if (!pointerPosition || !stageTransform) return;
if (!pointerPosition || !stageTransform) {
return;
}
const scaledCursorPosition = stageTransform.invert().point(pointerPosition);

View File

@ -80,19 +80,19 @@ const ControlNet = (props: ControlNetProps) => {
sx={{
flexDir: 'column',
gap: 3,
p: 3,
p: 2,
borderRadius: 'base',
position: 'relative',
bg: 'base.200',
bg: 'base.250',
_dark: {
bg: 'base.850',
bg: 'base.750',
},
}}
>
<Flex sx={{ gap: 2, alignItems: 'center' }}>
<IAISwitch
tooltip={'Toggle this ControlNet'}
aria-label={'Toggle this ControlNet'}
tooltip="Toggle this ControlNet"
aria-label="Toggle this ControlNet"
isChecked={isEnabled}
onChange={handleToggleIsEnabled}
/>
@ -194,7 +194,7 @@ const ControlNet = (props: ControlNetProps) => {
aspectRatio: '1/1',
}}
>
<ControlNetImagePreview controlNet={controlNet} height={28} />
<ControlNetImagePreview controlNet={controlNet} isSmall />
</Flex>
)}
</Flex>
@ -207,7 +207,7 @@ const ControlNet = (props: ControlNetProps) => {
{isExpanded && (
<>
<ControlNetImagePreview controlNet={controlNet} height="392px" />
<ControlNetImagePreview controlNet={controlNet} />
<ParamControlNetShouldAutoConfig controlNet={controlNet} />
<ControlNetProcessorComponent controlNet={controlNet} />
</>

View File

@ -1,14 +1,14 @@
import { Box, Flex, Spinner, SystemStyleObject } from '@chakra-ui/react';
import { Box, Flex, Spinner } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'features/dnd/types';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'features/dnd/types';
import { memo, useCallback, useMemo, useState } from 'react';
import { FaUndo } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
@ -21,7 +21,7 @@ import {
type Props = {
controlNet: ControlNetConfig;
height: SystemStyleObject['h'];
isSmall?: boolean;
};
const selector = createSelector(
@ -36,15 +36,14 @@ const selector = createSelector(
defaultSelectorOptions
);
const ControlNetImagePreview = (props: Props) => {
const { height } = props;
const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
const {
controlImage: controlImageName,
processedControlImage: processedControlImageName,
processorType,
isEnabled,
controlNetId,
} = props.controlNet;
} = controlNet;
const dispatch = useAppDispatch();
@ -109,7 +108,7 @@ const ControlNetImagePreview = (props: Props) => {
sx={{
position: 'relative',
w: 'full',
h: height,
h: isSmall ? 28 : 366, // magic no touch
alignItems: 'center',
justifyContent: 'center',
pointerEvents: isEnabled ? 'auto' : 'none',

View File

@ -4,7 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISwitch from 'common/components/IAISwitch';
import { isControlNetEnabledToggled } from 'features/controlNet/store/controlNetSlice';
import { useCallback } from 'react';
import { memo, useCallback } from 'react';
const selector = createSelector(
stateSelector,
@ -36,4 +36,4 @@ const ParamControlNetFeatureToggle = () => {
);
};
export default ParamControlNetFeatureToggle;
export default memo(ParamControlNetFeatureToggle);

View File

@ -23,7 +23,7 @@ const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
return (
<IAISlider
isDisabled={!isEnabled}
label={'Weight'}
label="Weight"
value={weight}
onChange={handleWeightChanged}
min={0}

View File

@ -8,6 +8,7 @@ import ParamDynamicPromptsCombinatorial from './ParamDynamicPromptsCombinatorial
import ParamDynamicPromptsToggle from './ParamDynamicPromptsEnabled';
import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts';
import { useFeatureStatus } from '../../system/hooks/useFeatureStatus';
import { memo } from 'react';
const selector = createSelector(
stateSelector,
@ -40,4 +41,4 @@ const ParamDynamicPromptsCollapse = () => {
);
};
export default ParamDynamicPromptsCollapse;
export default memo(ParamDynamicPromptsCollapse);

View File

@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISwitch from 'common/components/IAISwitch';
import { useCallback } from 'react';
import { memo, useCallback } from 'react';
import { combinatorialToggled } from '../store/dynamicPromptsSlice';
const selector = createSelector(
@ -34,4 +34,4 @@ const ParamDynamicPromptsCombinatorial = () => {
);
};
export default ParamDynamicPromptsCombinatorial;
export default memo(ParamDynamicPromptsCombinatorial);

View File

@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISwitch from 'common/components/IAISwitch';
import { useCallback } from 'react';
import { memo, useCallback } from 'react';
import { isEnabledToggled } from '../store/dynamicPromptsSlice';
const selector = createSelector(
@ -33,4 +33,4 @@ const ParamDynamicPromptsToggle = () => {
);
};
export default ParamDynamicPromptsToggle;
export default memo(ParamDynamicPromptsToggle);

View File

@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider';
import { useCallback } from 'react';
import { memo, useCallback } from 'react';
import {
maxPromptsChanged,
maxPromptsReset,
@ -60,4 +60,4 @@ const ParamDynamicPromptsMaxPrompts = () => {
);
};
export default ParamDynamicPromptsMaxPrompts;
export default memo(ParamDynamicPromptsMaxPrompts);

View File

@ -13,7 +13,7 @@ import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSe
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { forEach } from 'lodash-es';
import { PropsWithChildren, useCallback, useMemo, useRef } from 'react';
import { PropsWithChildren, memo, useCallback, useMemo, useRef } from 'react';
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
@ -118,7 +118,7 @@ const ParamEmbeddingPopover = (props: Props) => {
<IAIMantineSearchableSelect
inputRef={inputRef}
autoFocus
placeholder={'Add Embedding'}
placeholder="Add Embedding"
value={null}
data={data}
nothingFound="No matching Embeddings"
@ -140,4 +140,4 @@ const ParamEmbeddingPopover = (props: Props) => {
);
};
export default ParamEmbeddingPopover;
export default memo(ParamEmbeddingPopover);

View File

@ -1,4 +1,5 @@
import { Badge, Flex } from '@chakra-ui/react';
import { memo } from 'react';
const AutoAddIcon = () => {
return (
@ -20,4 +21,4 @@ const AutoAddIcon = () => {
);
};
export default AutoAddIcon;
export default memo(AutoAddIcon);

Some files were not shown because too many files have changed in this diff Show More