diff --git a/.gitignore b/.gitignore index cc000de20e..44a0864b5b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,23 +1,8 @@ -# ignore default image save location and model symbolic link .idea/ -embeddings/ -outputs/ -models/ldm/stable-diffusion-v1/model.ckpt -**/restoration/codeformer/weights - -# ignore user models config -configs/models.user.yaml -config/models.user.yml -invokeai.init -.version -.last_model # ignore the Anaconda/Miniconda installer used while building Docker image anaconda.sh -# ignore a directory which serves as a place for initial images -inputs/ - # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -189,39 +174,17 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ -src **/__pycache__/ -outputs -# Logs and associated folders -# created from generated embeddings. -logs -testtube -checkpoints # If it's a Mac .DS_Store -invokeai/frontend/yarn.lock -invokeai/frontend/node_modules - # Let the frontend manage its own gitignore !invokeai/frontend/web/* # Scratch folder .scratch/ .vscode/ -gfpgan/ -models/ldm/stable-diffusion-v1/*.sha256 - - -# GFPGAN model files -gfpgan/ - -# config file (will be created by installer) -configs/models.yaml - -# ignore initfile -.invokeai # ignore environment.yml and requirements.txt # these are links to the real files in environments-and-requirements diff --git a/docs/features/CONFIGURATION.md b/docs/features/CONFIGURATION.md index 09e6143e95..6920d3d97f 100644 --- a/docs/features/CONFIGURATION.md +++ b/docs/features/CONFIGURATION.md @@ -175,22 +175,27 @@ These configuration settings allow you to enable and disable various InvokeAI fe | `internet_available` | `true` | When a resource is not available locally, try to fetch it via the internet | | `log_tokenization` | `false` | Before each text2image generation, print a color-coded representation of the prompt to the console; this can help understand why a prompt is not working as expected | | `patchmatch` | `true` | Activate the "patchmatch" algorithm for improved inpainting | -| `restore` | `true` | Activate the facial restoration features (DEPRECATED; restoration features will be removed in 3.0.0) | -### Memory/Performance +### Generation These options tune InvokeAI's memory and performance characteristics. -| Setting | Default Value | Description | -|----------|----------------|--------------| -| `always_use_cpu` | `false` | Use the CPU to generate images, even if a GPU is available | -| `free_gpu_mem` | `false` | Aggressively free up GPU memory after each operation; this will allow you to run in low-VRAM environments with some performance penalties | -| `max_cache_size` | `6` | Amount of CPU RAM (in GB) to reserve for caching models in memory; more cache allows you to keep models in memory and switch among them quickly | -| `max_vram_cache_size` | `2.75` | Amount of GPU VRAM (in GB) to reserve for caching models in VRAM; more cache speeds up generation but reduces the size of the images that can be generated. This can be set to zero to maximize the amount of memory available for generation. | -| `precision` | `auto` | Floating point precision. One of `auto`, `float16` or `float32`. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system | -| `sequential_guidance` | `false` | Calculate guidance in serial rather than in parallel, lowering memory requirements at the cost of some performance loss | -| `xformers_enabled` | `true` | If the x-formers memory-efficient attention module is installed, activate it for better memory usage and generation speed| -| `tiled_decode` | `false` | If true, then during the VAE decoding phase the image will be decoded a section at a time, reducing memory consumption at the cost of a performance hit | +| Setting | Default Value | Description | +|-----------------------|---------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `sequential_guidance` | `false` | Calculate guidance in serial rather than in parallel, lowering memory requirements at the cost of some performance loss | +| `attention_type` | `auto` | Select the type of attention to use. One of `auto`,`normal`,`xformers`,`sliced`, or `torch-sdp` | +| `attention_slice_size` | `auto` | When "sliced" attention is selected, set the slice size. One of `auto`, `balanced`, `max` or the integers 1-8| +| `force_tiled_decode` | `false` | Force the VAE step to decode in tiles, reducing memory consumption at the cost of performance | + +### Device + +These options configure the generation execution device. + +| Setting | Default Value | Description | +|-----------------------|---------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `device` | `auto` | Preferred execution device. One of `auto`, `cpu`, `cuda`, `cuda:1`, `mps`. `auto` will choose the device depending on the hardware platform and the installed torch capabilities. | +| `precision` | `auto` | Floating point precision. One of `auto`, `float16` or `float32`. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system | + ### Paths diff --git a/invokeai/app/api/routers/app_info.py b/invokeai/app/api/routers/app_info.py index 9d9e47d2ef..b69a0b9a03 100644 --- a/invokeai/app/api/routers/app_info.py +++ b/invokeai/app/api/routers/app_info.py @@ -55,7 +55,7 @@ async def get_version() -> AppVersion: @app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig) async def get_config() -> AppConfig: - infill_methods = ["tile"] + infill_methods = ["tile", "lama"] if PatchMatch.patchmatch_available(): infill_methods.append("patchmatch") diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 0a31116878..b34000dc04 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -122,6 +122,7 @@ def custom_openapi(): output_schemas = schema(output_types, ref_prefix="#/components/schemas/") for schema_key, output_schema in output_schemas["definitions"].items(): + output_schema["class"] = "output" openapi_schema["components"]["schemas"][schema_key] = output_schema # TODO: note that we assume the schema_key here is the TYPE.__name__ @@ -130,8 +131,8 @@ def custom_openapi(): # Add Node Editor UI helper schemas ui_config_schemas = schema([UIConfigBase, _InputField, _OutputField], ref_prefix="#/components/schemas/") - for schema_key, output_schema in ui_config_schemas["definitions"].items(): - openapi_schema["components"]["schemas"][schema_key] = output_schema + for schema_key, ui_config_schema in ui_config_schemas["definitions"].items(): + openapi_schema["components"]["schemas"][schema_key] = ui_config_schema # Add a reference to the output type to additionalProperties of the invoker schema for invoker in all_invocations: @@ -140,8 +141,8 @@ def custom_openapi(): output_type_title = output_type_titles[output_type.__name__] invoker_schema = openapi_schema["components"]["schemas"][invoker_name] outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"} - invoker_schema["output"] = outputs_ref + invoker_schema["class"] = "invocation" from invokeai.backend.model_management.models import get_model_config_enums diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 6094c868d9..cbf5d1bfae 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -71,6 +71,9 @@ class FieldDescriptions: safe_mode = "Whether or not to use safe mode" scribble_mode = "Whether or not to use scribble mode" scale_factor = "The factor by which to scale" + blend_alpha = ( + "Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B." + ) num_1 = "The first number" num_2 = "The second number" mask = "The mask to use for the operation" @@ -140,6 +143,7 @@ class UIType(str, Enum): # region Misc FilePath = "FilePath" Enum = "enum" + Scheduler = "Scheduler" # endregion @@ -166,6 +170,7 @@ class _InputField(BaseModel): ui_hidden: bool ui_type: Optional[UIType] ui_component: Optional[UIComponent] + ui_order: Optional[int] class _OutputField(BaseModel): @@ -178,6 +183,7 @@ class _OutputField(BaseModel): ui_hidden: bool ui_type: Optional[UIType] + ui_order: Optional[int] def InputField( @@ -211,6 +217,7 @@ def InputField( ui_type: Optional[UIType] = None, ui_component: Optional[UIComponent] = None, ui_hidden: bool = False, + ui_order: Optional[int] = None, **kwargs: Any, ) -> Any: """ @@ -269,6 +276,7 @@ def InputField( ui_type=ui_type, ui_component=ui_component, ui_hidden=ui_hidden, + ui_order=ui_order, **kwargs, ) @@ -302,6 +310,7 @@ def OutputField( repr: bool = True, ui_type: Optional[UIType] = None, ui_hidden: bool = False, + ui_order: Optional[int] = None, **kwargs: Any, ) -> Any: """ @@ -348,6 +357,7 @@ def OutputField( repr=repr, ui_type=ui_type, ui_hidden=ui_hidden, + ui_order=ui_order, **kwargs, ) @@ -376,7 +386,7 @@ class BaseInvocationOutput(BaseModel): """Base class for all invocation outputs""" # All outputs must include a type name like this: - # type: Literal['your_output_name'] + # type: Literal['your_output_name'] # noqa f821 @classmethod def get_all_subclasses_tuple(cls): @@ -389,6 +399,13 @@ class BaseInvocationOutput(BaseModel): toprocess.extend(next_subclasses) return tuple(subclasses) + class Config: + @staticmethod + def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None: + if "required" not in schema or not isinstance(schema["required"], list): + schema["required"] = list() + schema["required"].extend(["type"]) + class RequiredConnectionException(Exception): """Raised when an field which requires a connection did not receive a value.""" @@ -410,7 +427,7 @@ class BaseInvocation(ABC, BaseModel): """ # All invocations must include a type name like this: - # type: Literal['your_output_name'] + # type: Literal['your_output_name'] # noqa f821 @classmethod def get_all_subclasses(cls): @@ -449,6 +466,9 @@ class BaseInvocation(ABC, BaseModel): schema["title"] = uiconfig.title if uiconfig and hasattr(uiconfig, "tags"): schema["tags"] = uiconfig.tags + if "required" not in schema or not isinstance(schema["required"], list): + schema["required"] = list() + schema["required"].extend(["type", "id"]) @abstractmethod def invoke(self, context: InvocationContext) -> BaseInvocationOutput: @@ -485,7 +505,7 @@ class BaseInvocation(ABC, BaseModel): raise MissingInputException(self.__fields__["type"].default, field_name) return self.invoke(context) - id: str = InputField(description="The id of this node. Must be unique among all nodes.") + id: str = Field(description="The id of this node. Must be unique among all nodes.") is_intermediate: bool = InputField( default=False, description="Whether or not this node is an intermediate node.", input=Input.Direct ) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 7c5ddf815d..8a4cadc139 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -232,7 +232,7 @@ class SDXLPromptInvocationBase: dtype_for_device_getter=torch_dtype, truncate_long_prompts=False, # TODO: returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip - requires_pooled=True, + requires_pooled=get_pooled, ) conjunction = Compel.parse_prompt_string(prompt) diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index f4a1648196..36157e195a 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -8,7 +8,7 @@ import numpy from PIL import Image, ImageChops, ImageFilter, ImageOps from invokeai.app.invocations.metadata import CoreMetadata -from invokeai.app.invocations.primitives import ImageField, ImageOutput +from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark from invokeai.backend.image_util.safety_checker import SafetyChecker @@ -41,6 +41,39 @@ class ShowImageInvocation(BaseInvocation): ) +@title("Blank Image") +@tags("image") +class BlankImageInvocation(BaseInvocation): + """Creates a blank image and forwards it to the pipeline""" + + # Metadata + type: Literal["blank_image"] = "blank_image" + + # Inputs + width: int = InputField(default=512, description="The width of the image") + height: int = InputField(default=512, description="The height of the image") + mode: Literal["RGB", "RGBA"] = InputField(default="RGB", description="The mode of the image") + color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color of the image") + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = Image.new(mode=self.mode, size=(self.width, self.height), color=self.color.tuple()) + + image_dto = context.services.images.create( + image=image, + image_origin=ResourceOrigin.INTERNAL, + image_category=ImageCategory.GENERAL, + node_id=self.id, + session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, + ) + + return ImageOutput( + image=ImageField(image_name=image_dto.image_name), + width=image_dto.width, + height=image_dto.height, + ) + + @title("Crop Image") @tags("image", "crop") class ImageCropInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index fea418567b..78b641b210 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -1,23 +1,25 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team +import math from typing import Literal, Optional, get_args import numpy as np -import math from PIL import Image, ImageOps -from invokeai.app.invocations.primitives import ImageField, ImageOutput, ColorField +from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput from invokeai.app.util.misc import SEED_MAX, get_random_seed +from invokeai.backend.image_util.lama import LaMA from invokeai.backend.image_util.patchmatch import PatchMatch from ..models.image import ImageCategory, ResourceOrigin -from .baseinvocation import BaseInvocation, InputField, InvocationContext, title, tags +from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title def infill_methods() -> list[str]: methods = [ "tile", "solid", + "lama", ] if PatchMatch.patchmatch_available(): methods.insert(0, "patchmatch") @@ -28,6 +30,11 @@ INFILL_METHODS = Literal[tuple(infill_methods())] DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile" +def infill_lama(im: Image.Image) -> Image.Image: + lama = LaMA() + return lama(im) + + def infill_patchmatch(im: Image.Image) -> Image.Image: if im.mode != "RGBA": return im @@ -90,7 +97,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] return im # Find all invalid tiles and replace with a random valid tile - replace_count = (tiles_mask is False).sum() + replace_count = (tiles_mask == False).sum() # noqa: E712 rng = np.random.default_rng(seed=seed) tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :] @@ -218,3 +225,34 @@ class InfillPatchMatchInvocation(BaseInvocation): width=image_dto.width, height=image_dto.height, ) + + +@title("LaMa Infill") +@tags("image", "inpaint") +class LaMaInfillInvocation(BaseInvocation): + """Infills transparent areas of an image using the LaMa model""" + + type: Literal["infill_lama"] = "infill_lama" + + # Inputs + image: ImageField = InputField(description="The image to infill") + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.services.images.get_pil_image(self.image.image_name) + + infilled = infill_lama(image.copy()) + + image_dto = context.services.images.create( + image=infilled, + image_origin=ResourceOrigin.INTERNAL, + image_category=ImageCategory.GENERAL, + node_id=self.id, + session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, + ) + + return ImageOutput( + image=ImageField(image_name=image_dto.image_name), + width=image_dto.width, + height=image_dto.height, + ) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index e12cc18f42..314301663b 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -4,6 +4,7 @@ from contextlib import ExitStack from typing import List, Literal, Optional, Union import einops +import numpy as np import torch import torchvision.transforms as T from diffusers.image_processor import VaeImageProcessor @@ -106,24 +107,28 @@ class DenoiseLatentsInvocation(BaseInvocation): # Inputs positive_conditioning: ConditioningField = InputField( - description=FieldDescriptions.positive_cond, input=Input.Connection + description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0 ) negative_conditioning: ConditioningField = InputField( - description=FieldDescriptions.negative_cond, input=Input.Connection + description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1 ) - noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection) + noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection, ui_order=3) steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps) cfg_scale: Union[float, List[float]] = InputField( - default=7.5, ge=1, description=FieldDescriptions.cfg_scale, ui_type=UIType.Float + default=7.5, ge=1, description=FieldDescriptions.cfg_scale, ui_type=UIType.Float, title="CFG Scale" ) denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start) denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end) - scheduler: SAMPLER_NAME_VALUES = InputField(default="euler", description=FieldDescriptions.scheduler) - unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection) - control: Union[ControlField, list[ControlField]] = InputField( - default=None, description=FieldDescriptions.control, input=Input.Connection + scheduler: SAMPLER_NAME_VALUES = InputField( + default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler + ) + unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet", ui_order=2) + control: Union[ControlField, list[ControlField]] = InputField( + default=None, description=FieldDescriptions.control, input=Input.Connection, ui_order=5 + ) + latents: Optional[LatentsField] = InputField( + description=FieldDescriptions.latents, input=Input.Connection, ui_order=4 ) - latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection) mask: Optional[ImageField] = InputField( default=None, description=FieldDescriptions.mask, @@ -453,7 +458,7 @@ class DenoiseLatentsInvocation(BaseInvocation): @title("Latents to Image") -@tags("latents", "image", "vae") +@tags("latents", "image", "vae", "l2i") class LatentsToImageInvocation(BaseInvocation): """Generates an image from latents.""" @@ -641,7 +646,7 @@ class ScaleLatentsInvocation(BaseInvocation): @title("Image to Latents") -@tags("latents", "image", "vae") +@tags("latents", "image", "vae", "i2l") class ImageToLatentsInvocation(BaseInvocation): """Encodes an image into latents.""" @@ -720,3 +725,81 @@ class ImageToLatentsInvocation(BaseInvocation): latents = latents.to("cpu") context.services.latents.save(name, latents) return build_latents_output(latents_name=name, latents=latents, seed=None) + + +@title("Blend Latents") +@tags("latents", "blend") +class BlendLatentsInvocation(BaseInvocation): + """Blend two latents using a given alpha. Latents must have same size.""" + + type: Literal["lblend"] = "lblend" + + # Inputs + latents_a: LatentsField = InputField( + description=FieldDescriptions.latents, + input=Input.Connection, + ) + latents_b: LatentsField = InputField( + description=FieldDescriptions.latents, + input=Input.Connection, + ) + alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha) + + def invoke(self, context: InvocationContext) -> LatentsOutput: + latents_a = context.services.latents.get(self.latents_a.latents_name) + latents_b = context.services.latents.get(self.latents_b.latents_name) + + if latents_a.shape != latents_b.shape: + raise "Latents to blend must be the same size." + + # TODO: + device = choose_torch_device() + + def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): + """ + Spherical linear interpolation + Args: + t (float/np.ndarray): Float value between 0.0 and 1.0 + v0 (np.ndarray): Starting vector + v1 (np.ndarray): Final vector + DOT_THRESHOLD (float): Threshold for considering the two vectors as + colineal. Not recommended to alter this. + Returns: + v2 (np.ndarray): Interpolation vector between v0 and v1 + """ + inputs_are_torch = False + if not isinstance(v0, np.ndarray): + inputs_are_torch = True + v0 = v0.detach().cpu().numpy() + if not isinstance(v1, np.ndarray): + inputs_are_torch = True + v1 = v1.detach().cpu().numpy() + + dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) + if np.abs(dot) > DOT_THRESHOLD: + v2 = (1 - t) * v0 + t * v1 + else: + theta_0 = np.arccos(dot) + sin_theta_0 = np.sin(theta_0) + theta_t = theta_0 * t + sin_theta_t = np.sin(theta_t) + s0 = np.sin(theta_0 - theta_t) / sin_theta_0 + s1 = sin_theta_t / sin_theta_0 + v2 = s0 * v0 + s1 * v1 + + if inputs_are_torch: + v2 = torch.from_numpy(v2).to(device) + + return v2 + + # blend + blended_latents = slerp(self.alpha, latents_a, latents_b) + + # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 + blended_latents = blended_latents.to("cpu") + torch.cuda.empty_cache() + + name = f"{context.graph_execution_state_id}__{self.id}" + # context.services.latents.set(name, resized_latents) + context.services.latents.save(name, blended_latents) + return build_latents_output(latents_name=name, latents=blended_latents) diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py index 13e3d92f52..80cdc09221 100644 --- a/invokeai/app/invocations/math.py +++ b/invokeai/app/invocations/math.py @@ -21,7 +21,7 @@ class AddInvocation(BaseInvocation): b: int = InputField(default=0, description=FieldDescriptions.num_2) def invoke(self, context: InvocationContext) -> IntegerOutput: - return IntegerOutput(a=self.a + self.b) + return IntegerOutput(value=self.a + self.b) @title("Subtract Integers") @@ -36,7 +36,7 @@ class SubtractInvocation(BaseInvocation): b: int = InputField(default=0, description=FieldDescriptions.num_2) def invoke(self, context: InvocationContext) -> IntegerOutput: - return IntegerOutput(a=self.a - self.b) + return IntegerOutput(value=self.a - self.b) @title("Multiply Integers") @@ -51,7 +51,7 @@ class MultiplyInvocation(BaseInvocation): b: int = InputField(default=0, description=FieldDescriptions.num_2) def invoke(self, context: InvocationContext) -> IntegerOutput: - return IntegerOutput(a=self.a * self.b) + return IntegerOutput(value=self.a * self.b) @title("Divide Integers") @@ -66,7 +66,7 @@ class DivideInvocation(BaseInvocation): b: int = InputField(default=0, description=FieldDescriptions.num_2) def invoke(self, context: InvocationContext) -> IntegerOutput: - return IntegerOutput(a=int(self.a / self.b)) + return IntegerOutput(value=int(self.a / self.b)) @title("Random Integer") @@ -81,4 +81,4 @@ class RandomIntInvocation(BaseInvocation): high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value") def invoke(self, context: InvocationContext) -> IntegerOutput: - return IntegerOutput(a=np.random.randint(self.low, self.high)) + return IntegerOutput(value=np.random.randint(self.low, self.high)) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index cecca78651..3cae4b3383 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -72,7 +72,7 @@ class LoRAModelField(BaseModel): base_model: BaseModelType = Field(description="Base model") -@title("Main Model Loader") +@title("Main Model") @tags("model") class MainModelLoaderInvocation(BaseInvocation): """Loads a main model, outputting its submodels.""" @@ -179,7 +179,7 @@ class LoraLoaderOutput(BaseInvocationOutput): # fmt: on -@title("LoRA Loader") +@title("LoRA") @tags("lora", "model") class LoraLoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" @@ -257,7 +257,7 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput): # fmt: on -@title("SDXL LoRA Loader") +@title("SDXL LoRA") @tags("sdxl", "lora", "model") class SDXLLoraLoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" @@ -356,7 +356,7 @@ class VaeLoaderOutput(BaseInvocationOutput): vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE") -@title("VAE Loader") +@title("VAE") @tags("vae", "model") class VaeLoaderInvocation(BaseInvocation): """Loads a VAE model, outputting a VaeLoaderOutput""" diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index 3e65c1e55d..b16694357b 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -169,7 +169,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation): ui_type=UIType.Float, ) scheduler: SAMPLER_NAME_VALUES = InputField( - default="euler", description=FieldDescriptions.scheduler, input=Input.Direct + default="euler", description=FieldDescriptions.scheduler, input=Input.Direct, ui_type=UIType.Scheduler ) precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision) unet: UNetField = InputField( @@ -406,7 +406,7 @@ class OnnxModelField(BaseModel): model_type: ModelType = Field(description="Model Type") -@title("ONNX Model Loader") +@title("ONNX Main Model") @tags("onnx", "model") class OnnxModelLoaderInvocation(BaseInvocation): """Loads a main model, outputting its submodels.""" diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index f32cb14f3a..607423e570 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -2,8 +2,8 @@ from typing import Literal, Optional, Tuple -from pydantic import BaseModel, Field import torch +from pydantic import BaseModel, Field from .baseinvocation import ( BaseInvocation, @@ -33,7 +33,7 @@ class BooleanOutput(BaseInvocationOutput): """Base class for nodes that output a single boolean""" type: Literal["boolean_output"] = "boolean_output" - a: bool = OutputField(description="The output boolean") + value: bool = OutputField(description="The output boolean") class BooleanCollectionOutput(BaseInvocationOutput): @@ -42,9 +42,7 @@ class BooleanCollectionOutput(BaseInvocationOutput): type: Literal["boolean_collection_output"] = "boolean_collection_output" # Outputs - collection: list[bool] = OutputField( - default_factory=list, description="The output boolean collection", ui_type=UIType.BooleanCollection - ) + collection: list[bool] = OutputField(description="The output boolean collection", ui_type=UIType.BooleanCollection) @title("Boolean Primitive") @@ -55,10 +53,10 @@ class BooleanInvocation(BaseInvocation): type: Literal["boolean"] = "boolean" # Inputs - a: bool = InputField(default=False, description="The boolean value") + value: bool = InputField(default=False, description="The boolean value") def invoke(self, context: InvocationContext) -> BooleanOutput: - return BooleanOutput(a=self.a) + return BooleanOutput(value=self.value) @title("Boolean Primitive Collection") @@ -70,7 +68,7 @@ class BooleanCollectionInvocation(BaseInvocation): # Inputs collection: list[bool] = InputField( - default=False, description="The collection of boolean values", ui_type=UIType.BooleanCollection + default_factory=list, description="The collection of boolean values", ui_type=UIType.BooleanCollection ) def invoke(self, context: InvocationContext) -> BooleanCollectionOutput: @@ -86,7 +84,7 @@ class IntegerOutput(BaseInvocationOutput): """Base class for nodes that output a single integer""" type: Literal["integer_output"] = "integer_output" - a: int = OutputField(description="The output integer") + value: int = OutputField(description="The output integer") class IntegerCollectionOutput(BaseInvocationOutput): @@ -95,9 +93,7 @@ class IntegerCollectionOutput(BaseInvocationOutput): type: Literal["integer_collection_output"] = "integer_collection_output" # Outputs - collection: list[int] = OutputField( - default_factory=list, description="The int collection", ui_type=UIType.IntegerCollection - ) + collection: list[int] = OutputField(description="The int collection", ui_type=UIType.IntegerCollection) @title("Integer Primitive") @@ -108,10 +104,10 @@ class IntegerInvocation(BaseInvocation): type: Literal["integer"] = "integer" # Inputs - a: int = InputField(default=0, description="The integer value") + value: int = InputField(default=0, description="The integer value") def invoke(self, context: InvocationContext) -> IntegerOutput: - return IntegerOutput(a=self.a) + return IntegerOutput(value=self.value) @title("Integer Primitive Collection") @@ -139,7 +135,7 @@ class FloatOutput(BaseInvocationOutput): """Base class for nodes that output a single float""" type: Literal["float_output"] = "float_output" - a: float = OutputField(description="The output float") + value: float = OutputField(description="The output float") class FloatCollectionOutput(BaseInvocationOutput): @@ -148,9 +144,7 @@ class FloatCollectionOutput(BaseInvocationOutput): type: Literal["float_collection_output"] = "float_collection_output" # Outputs - collection: list[float] = OutputField( - default_factory=list, description="The float collection", ui_type=UIType.FloatCollection - ) + collection: list[float] = OutputField(description="The float collection", ui_type=UIType.FloatCollection) @title("Float Primitive") @@ -161,10 +155,10 @@ class FloatInvocation(BaseInvocation): type: Literal["float"] = "float" # Inputs - param: float = InputField(default=0.0, description="The float value") + value: float = InputField(default=0.0, description="The float value") def invoke(self, context: InvocationContext) -> FloatOutput: - return FloatOutput(a=self.param) + return FloatOutput(value=self.value) @title("Float Primitive Collection") @@ -176,7 +170,7 @@ class FloatCollectionInvocation(BaseInvocation): # Inputs collection: list[float] = InputField( - default=0, description="The collection of float values", ui_type=UIType.FloatCollection + default_factory=list, description="The collection of float values", ui_type=UIType.FloatCollection ) def invoke(self, context: InvocationContext) -> FloatCollectionOutput: @@ -192,7 +186,7 @@ class StringOutput(BaseInvocationOutput): """Base class for nodes that output a single string""" type: Literal["string_output"] = "string_output" - text: str = OutputField(description="The output string") + value: str = OutputField(description="The output string") class StringCollectionOutput(BaseInvocationOutput): @@ -201,9 +195,7 @@ class StringCollectionOutput(BaseInvocationOutput): type: Literal["string_collection_output"] = "string_collection_output" # Outputs - collection: list[str] = OutputField( - default_factory=list, description="The output strings", ui_type=UIType.StringCollection - ) + collection: list[str] = OutputField(description="The output strings", ui_type=UIType.StringCollection) @title("String Primitive") @@ -214,10 +206,10 @@ class StringInvocation(BaseInvocation): type: Literal["string"] = "string" # Inputs - text: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea) + value: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea) def invoke(self, context: InvocationContext) -> StringOutput: - return StringOutput(text=self.text) + return StringOutput(value=self.value) @title("String Primitive Collection") @@ -229,7 +221,7 @@ class StringCollectionInvocation(BaseInvocation): # Inputs collection: list[str] = InputField( - default=0, description="The collection of string values", ui_type=UIType.StringCollection + default_factory=list, description="The collection of string values", ui_type=UIType.StringCollection ) def invoke(self, context: InvocationContext) -> StringCollectionOutput: @@ -262,9 +254,7 @@ class ImageCollectionOutput(BaseInvocationOutput): type: Literal["image_collection_output"] = "image_collection_output" # Outputs - collection: list[ImageField] = OutputField( - default_factory=list, description="The output images", ui_type=UIType.ImageCollection - ) + collection: list[ImageField] = OutputField(description="The output images", ui_type=UIType.ImageCollection) @title("Image Primitive") @@ -334,7 +324,6 @@ class LatentsCollectionOutput(BaseInvocationOutput): type: Literal["latents_collection_output"] = "latents_collection_output" collection: list[LatentsField] = OutputField( - default_factory=list, description=FieldDescriptions.latents, ui_type=UIType.LatentsCollection, ) @@ -365,7 +354,7 @@ class LatentsCollectionInvocation(BaseInvocation): # Inputs collection: list[LatentsField] = InputField( - default=0, description="The collection of latents tensors", ui_type=UIType.LatentsCollection + description="The collection of latents tensors", ui_type=UIType.LatentsCollection ) def invoke(self, context: InvocationContext) -> LatentsCollectionOutput: @@ -410,9 +399,7 @@ class ColorCollectionOutput(BaseInvocationOutput): type: Literal["color_collection_output"] = "color_collection_output" # Outputs - collection: list[ColorField] = OutputField( - default_factory=list, description="The output colors", ui_type=UIType.ColorCollection - ) + collection: list[ColorField] = OutputField(description="The output colors", ui_type=UIType.ColorCollection) @title("Color Primitive") @@ -455,7 +442,6 @@ class ConditioningCollectionOutput(BaseInvocationOutput): # Outputs collection: list[ConditioningField] = OutputField( - default_factory=list, description="The output conditioning tensors", ui_type=UIType.ConditioningCollection, ) diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 4efe30a3d9..fc224db14d 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -37,7 +37,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput): vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE") -@title("SDXL Main Model Loader") +@title("SDXL Main Model") @tags("model", "sdxl") class SDXLModelLoaderInvocation(BaseInvocation): """Loads an sdxl base model, outputting its submodels.""" @@ -122,7 +122,7 @@ class SDXLModelLoaderInvocation(BaseInvocation): ) -@title("SDXL Refiner Model Loader") +@title("SDXL Refiner Model") @tags("model", "sdxl", "refiner") class SDXLRefinerModelLoaderInvocation(BaseInvocation): """Loads an sdxl refiner model, outputting its submodels.""" diff --git a/invokeai/app/services/config/__init__.py b/invokeai/app/services/config/__init__.py new file mode 100644 index 0000000000..6a42f9e08c --- /dev/null +++ b/invokeai/app/services/config/__init__.py @@ -0,0 +1,8 @@ +""" +Init file for InvokeAI configure package +""" + +from .invokeai_config import ( # noqa F401 + InvokeAIAppConfig, + get_invokeai_config, +) diff --git a/invokeai/app/services/config/base.py b/invokeai/app/services/config/base.py new file mode 100644 index 0000000000..b83621c708 --- /dev/null +++ b/invokeai/app/services/config/base.py @@ -0,0 +1,239 @@ +# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team + +""" +Base class for the InvokeAI configuration system. +It defines a type of pydantic BaseSettings object that +is able to read and write from an omegaconf-based config file, +with overriding of settings from environment variables and/or +the command line. +""" + +from __future__ import annotations +import argparse +import os +import pydoc +import sys +from argparse import ArgumentParser +from omegaconf import OmegaConf, DictConfig, ListConfig +from pathlib import Path +from pydantic import BaseSettings +from typing import ClassVar, Dict, List, Literal, Union, get_origin, get_type_hints, get_args + + +class PagingArgumentParser(argparse.ArgumentParser): + """ + A custom ArgumentParser that uses pydoc to page its output. + It also supports reading defaults from an init file. + """ + + def print_help(self, file=None): + text = self.format_help() + pydoc.pager(text) + + +class InvokeAISettings(BaseSettings): + """ + Runtime configuration settings in which default values are + read from an omegaconf .yaml file. + """ + + initconf: ClassVar[DictConfig] = None + argparse_groups: ClassVar[Dict] = {} + + def parse_args(self, argv: list = sys.argv[1:]): + parser = self.get_parser() + opt = parser.parse_args(argv) + for name in self.__fields__: + if name not in self._excluded(): + value = getattr(opt, name) + if isinstance(value, ListConfig): + value = list(value) + elif isinstance(value, DictConfig): + value = dict(value) + setattr(self, name, value) + + def to_yaml(self) -> str: + """ + Return a YAML string representing our settings. This can be used + as the contents of `invokeai.yaml` to restore settings later. + """ + cls = self.__class__ + type = get_args(get_type_hints(cls)["type"])[0] + field_dict = dict({type: dict()}) + for name, field in self.__fields__.items(): + if name in cls._excluded_from_yaml(): + continue + category = field.field_info.extra.get("category") or "Uncategorized" + value = getattr(self, name) + if category not in field_dict[type]: + field_dict[type][category] = dict() + # keep paths as strings to make it easier to read + field_dict[type][category][name] = str(value) if isinstance(value, Path) else value + conf = OmegaConf.create(field_dict) + return OmegaConf.to_yaml(conf) + + @classmethod + def add_parser_arguments(cls, parser): + if "type" in get_type_hints(cls): + settings_stanza = get_args(get_type_hints(cls)["type"])[0] + else: + settings_stanza = "Uncategorized" + + env_prefix = cls.Config.env_prefix if hasattr(cls.Config, "env_prefix") else settings_stanza.upper() + + initconf = ( + cls.initconf.get(settings_stanza) + if cls.initconf and settings_stanza in cls.initconf + else OmegaConf.create() + ) + + # create an upcase version of the environment in + # order to achieve case-insensitive environment + # variables (the way Windows does) + upcase_environ = dict() + for key, value in os.environ.items(): + upcase_environ[key.upper()] = value + + fields = cls.__fields__ + cls.argparse_groups = {} + + for name, field in fields.items(): + if name not in cls._excluded(): + current_default = field.default + + category = field.field_info.extra.get("category", "Uncategorized") + env_name = env_prefix + "_" + name + if category in initconf and name in initconf.get(category): + field.default = initconf.get(category).get(name) + if env_name.upper() in upcase_environ: + field.default = upcase_environ[env_name.upper()] + cls.add_field_argument(parser, name, field) + + field.default = current_default + + @classmethod + def cmd_name(self, command_field: str = "type") -> str: + hints = get_type_hints(self) + if command_field in hints: + return get_args(hints[command_field])[0] + else: + return "Uncategorized" + + @classmethod + def get_parser(cls) -> ArgumentParser: + parser = PagingArgumentParser( + prog=cls.cmd_name(), + description=cls.__doc__, + ) + cls.add_parser_arguments(parser) + return parser + + @classmethod + def add_subparser(cls, parser: argparse.ArgumentParser): + parser.add_parser(cls.cmd_name(), help=cls.__doc__) + + @classmethod + def _excluded(self) -> List[str]: + # internal fields that shouldn't be exposed as command line options + return ["type", "initconf"] + + @classmethod + def _excluded_from_yaml(self) -> List[str]: + # combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options + return [ + "type", + "initconf", + "version", + "from_file", + "model", + "root", + "max_cache_size", + "max_vram_cache_size", + "always_use_cpu", + "free_gpu_mem", + "xformers_enabled", + "tiled_decode", + ] + + class Config: + env_file_encoding = "utf-8" + arbitrary_types_allowed = True + case_sensitive = True + + @classmethod + def add_field_argument(cls, command_parser, name: str, field, default_override=None): + field_type = get_type_hints(cls).get(name) + default = ( + default_override + if default_override is not None + else field.default + if field.default_factory is None + else field.default_factory() + ) + if category := field.field_info.extra.get("category"): + if category not in cls.argparse_groups: + cls.argparse_groups[category] = command_parser.add_argument_group(category) + argparse_group = cls.argparse_groups[category] + else: + argparse_group = command_parser + + if get_origin(field_type) == Literal: + allowed_values = get_args(field.type_) + allowed_types = set() + for val in allowed_values: + allowed_types.add(type(val)) + allowed_types_list = list(allowed_types) + field_type = allowed_types_list[0] if len(allowed_types) == 1 else int_or_float_or_str + + argparse_group.add_argument( + f"--{name}", + dest=name, + type=field_type, + default=default, + choices=allowed_values, + help=field.field_info.description, + ) + + elif get_origin(field_type) == Union: + argparse_group.add_argument( + f"--{name}", + dest=name, + type=int_or_float_or_str, + default=default, + help=field.field_info.description, + ) + + elif get_origin(field_type) == list: + argparse_group.add_argument( + f"--{name}", + dest=name, + nargs="*", + type=field.type_, + default=default, + action=argparse.BooleanOptionalAction if field.type_ == bool else "store", + help=field.field_info.description, + ) + else: + argparse_group.add_argument( + f"--{name}", + dest=name, + type=field.type_, + default=default, + action=argparse.BooleanOptionalAction if field.type_ == bool else "store", + help=field.field_info.description, + ) + + +def int_or_float_or_str(value: str) -> Union[int, float, str]: + """ + Workaround for argparse type checking. + """ + try: + return int(value) + except Exception as e: # noqa F841 + pass + try: + return float(value) + except Exception as e: # noqa F841 + pass + return str(value) diff --git a/invokeai/app/services/config.py b/invokeai/app/services/config/invokeai_config.py similarity index 63% rename from invokeai/app/services/config.py rename to invokeai/app/services/config/invokeai_config.py index a9e5bbee98..728fe188b5 100644 --- a/invokeai/app/services/config.py +++ b/invokeai/app/services/config/invokeai_config.py @@ -10,37 +10,49 @@ categories returned by `invokeai --help`. The file looks like this: [file: invokeai.yaml] InvokeAI: - Paths: - root: /home/lstein/invokeai-main - conf_path: configs/models.yaml - legacy_conf_dir: configs/stable-diffusion - outdir: outputs - autoimport_dir: null - Models: - model: stable-diffusion-1.5 - embeddings: true - Memory/Performance: - xformers_enabled: false - sequential_guidance: false - precision: float16 - max_cache_size: 6 - max_vram_cache_size: 0.5 - always_use_cpu: false - free_gpu_mem: false - Features: - esrgan: true - patchmatch: true - internet_available: true - log_tokenization: false Web Server: host: 127.0.0.1 - port: 8081 + port: 9090 allow_origins: [] allow_credentials: true allow_methods: - '*' allow_headers: - '*' + Features: + esrgan: true + internet_available: true + log_tokenization: false + patchmatch: true + ignore_missing_core_models: false + Paths: + autoimport_dir: autoimport + lora_dir: null + embedding_dir: null + controlnet_dir: null + conf_path: configs/models.yaml + models_dir: models + legacy_conf_dir: configs/stable-diffusion + db_dir: databases + outdir: /home/lstein/invokeai-main/outputs + use_memory_db: false + Logging: + log_handlers: + - console + log_format: plain + log_level: info + Model Cache: + ram: 13.5 + vram: 0.25 + lazy_offload: true + Device: + device: auto + precision: auto + Generation: + sequential_guidance: false + attention_type: xformers + attention_slice_size: auto + force_tiled_decode: false The default name of the configuration file is `invokeai.yaml`, located in INVOKEAI_ROOT. You can replace supersede this by providing any @@ -54,24 +66,23 @@ InvokeAIAppConfig.parse_args() will parse the contents of `sys.argv` at initialization time. You may pass a list of strings in the optional `argv` argument to use instead of the system argv: - conf.parse_args(argv=['--xformers_enabled']) + conf.parse_args(argv=['--log_tokenization']) It is also possible to set a value at initialization time. However, if you call parse_args() it may be overwritten. - conf = InvokeAIAppConfig(xformers_enabled=True) - conf.parse_args(argv=['--no-xformers']) - conf.xformers_enabled + conf = InvokeAIAppConfig(log_tokenization=True) + conf.parse_args(argv=['--no-log_tokenization']) + conf.log_tokenization # False - To avoid this, use `get_config()` to retrieve the application-wide configuration object. This will retain any properties set at object creation time: - conf = InvokeAIAppConfig.get_config(xformers_enabled=True) - conf.parse_args(argv=['--no-xformers']) - conf.xformers_enabled + conf = InvokeAIAppConfig.get_config(log_tokenization=True) + conf.parse_args(argv=['--no-log_tokenization']) + conf.log_tokenization # True Any setting can be overwritten by setting an environment variable of @@ -93,7 +104,7 @@ Typical usage at the top level file: # get global configuration and print its cache size conf = InvokeAIAppConfig.get_config() conf.parse_args() - print(conf.max_cache_size) + print(conf.ram_cache_size) Typical usage in a backend module: @@ -101,8 +112,7 @@ Typical usage in a backend module: # get global configuration and print its cache size value conf = InvokeAIAppConfig.get_config() - print(conf.max_cache_size) - + print(conf.ram_cache_size) Computed properties: @@ -159,15 +169,13 @@ two configs are kept in separate sections of the config file: """ from __future__ import annotations -import argparse -import pydoc import os -import sys -from argparse import ArgumentParser -from omegaconf import OmegaConf, DictConfig, ListConfig +from omegaconf import OmegaConf, DictConfig from pathlib import Path -from pydantic import BaseSettings, Field, parse_obj_as -from typing import ClassVar, Dict, List, Literal, Union, get_origin, get_type_hints, get_args +from pydantic import Field, parse_obj_as +from typing import ClassVar, Dict, List, Literal, Union, Optional, get_type_hints + +from .base import InvokeAISettings INIT_FILE = Path("invokeai.yaml") DB_FILE = Path("invokeai.db") @@ -175,195 +183,6 @@ LEGACY_INIT_FILE = Path("invokeai.init") DEFAULT_MAX_VRAM = 0.5 -class InvokeAISettings(BaseSettings): - """ - Runtime configuration settings in which default values are - read from an omegaconf .yaml file. - """ - - initconf: ClassVar[DictConfig] = None - argparse_groups: ClassVar[Dict] = {} - - def parse_args(self, argv: list = sys.argv[1:]): - parser = self.get_parser() - opt = parser.parse_args(argv) - for name in self.__fields__: - if name not in self._excluded(): - value = getattr(opt, name) - if isinstance(value, ListConfig): - value = list(value) - elif isinstance(value, DictConfig): - value = dict(value) - setattr(self, name, value) - - def to_yaml(self) -> str: - """ - Return a YAML string representing our settings. This can be used - as the contents of `invokeai.yaml` to restore settings later. - """ - cls = self.__class__ - type = get_args(get_type_hints(cls)["type"])[0] - field_dict = dict({type: dict()}) - for name, field in self.__fields__.items(): - if name in cls._excluded_from_yaml(): - continue - category = field.field_info.extra.get("category") or "Uncategorized" - value = getattr(self, name) - if category not in field_dict[type]: - field_dict[type][category] = dict() - # keep paths as strings to make it easier to read - field_dict[type][category][name] = str(value) if isinstance(value, Path) else value - conf = OmegaConf.create(field_dict) - return OmegaConf.to_yaml(conf) - - @classmethod - def add_parser_arguments(cls, parser): - if "type" in get_type_hints(cls): - settings_stanza = get_args(get_type_hints(cls)["type"])[0] - else: - settings_stanza = "Uncategorized" - - env_prefix = cls.Config.env_prefix if hasattr(cls.Config, "env_prefix") else settings_stanza.upper() - - initconf = ( - cls.initconf.get(settings_stanza) - if cls.initconf and settings_stanza in cls.initconf - else OmegaConf.create() - ) - - # create an upcase version of the environment in - # order to achieve case-insensitive environment - # variables (the way Windows does) - upcase_environ = dict() - for key, value in os.environ.items(): - upcase_environ[key.upper()] = value - - fields = cls.__fields__ - cls.argparse_groups = {} - - for name, field in fields.items(): - if name not in cls._excluded(): - current_default = field.default - - category = field.field_info.extra.get("category", "Uncategorized") - env_name = env_prefix + "_" + name - if category in initconf and name in initconf.get(category): - field.default = initconf.get(category).get(name) - if env_name.upper() in upcase_environ: - field.default = upcase_environ[env_name.upper()] - cls.add_field_argument(parser, name, field) - - field.default = current_default - - @classmethod - def cmd_name(self, command_field: str = "type") -> str: - hints = get_type_hints(self) - if command_field in hints: - return get_args(hints[command_field])[0] - else: - return "Uncategorized" - - @classmethod - def get_parser(cls) -> ArgumentParser: - parser = PagingArgumentParser( - prog=cls.cmd_name(), - description=cls.__doc__, - ) - cls.add_parser_arguments(parser) - return parser - - @classmethod - def add_subparser(cls, parser: argparse.ArgumentParser): - parser.add_parser(cls.cmd_name(), help=cls.__doc__) - - @classmethod - def _excluded(self) -> List[str]: - # internal fields that shouldn't be exposed as command line options - return ["type", "initconf"] - - @classmethod - def _excluded_from_yaml(self) -> List[str]: - # combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options - return [ - "type", - "initconf", - "version", - "from_file", - "model", - "root", - ] - - class Config: - env_file_encoding = "utf-8" - arbitrary_types_allowed = True - case_sensitive = True - - @classmethod - def add_field_argument(cls, command_parser, name: str, field, default_override=None): - field_type = get_type_hints(cls).get(name) - default = ( - default_override - if default_override is not None - else field.default - if field.default_factory is None - else field.default_factory() - ) - if category := field.field_info.extra.get("category"): - if category not in cls.argparse_groups: - cls.argparse_groups[category] = command_parser.add_argument_group(category) - argparse_group = cls.argparse_groups[category] - else: - argparse_group = command_parser - - if get_origin(field_type) == Literal: - allowed_values = get_args(field.type_) - allowed_types = set() - for val in allowed_values: - allowed_types.add(type(val)) - allowed_types_list = list(allowed_types) - field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore - - argparse_group.add_argument( - f"--{name}", - dest=name, - type=field_type, - default=default, - choices=allowed_values, - help=field.field_info.description, - ) - - elif get_origin(field_type) == list: - argparse_group.add_argument( - f"--{name}", - dest=name, - nargs="*", - type=field.type_, - default=default, - action=argparse.BooleanOptionalAction if field.type_ == bool else "store", - help=field.field_info.description, - ) - else: - argparse_group.add_argument( - f"--{name}", - dest=name, - type=field.type_, - default=default, - action=argparse.BooleanOptionalAction if field.type_ == bool else "store", - help=field.field_info.description, - ) - - -def _find_root() -> Path: - venv = Path(os.environ.get("VIRTUAL_ENV") or ".") - if os.environ.get("INVOKEAI_ROOT"): - root = Path(os.environ["INVOKEAI_ROOT"]) - elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]): - root = (venv.parent).resolve() - else: - root = Path("~/invokeai").expanduser().resolve() - return root - - class InvokeAIAppConfig(InvokeAISettings): """ Generate images using Stable Diffusion. Use "invokeai" to launch @@ -378,6 +197,8 @@ class InvokeAIAppConfig(InvokeAISettings): # fmt: off type: Literal["InvokeAI"] = "InvokeAI" + + # WEB host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server') port : int = Field(default=9090, description="Port to bind to", category='Web Server') allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", category='Web Server') @@ -385,20 +206,14 @@ class InvokeAIAppConfig(InvokeAISettings): allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", category='Web Server') allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", category='Web Server') + # FEATURES esrgan : bool = Field(default=True, description="Enable/disable upscaling code", category='Features') internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features') log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features') patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features') + ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', category='Features') - always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance') - free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance') - max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance') - max_vram_cache_size : float = Field(default=2.75, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance') - precision : Literal['auto', 'float16', 'float32', 'autocast'] = Field(default='auto', description='Floating point precision', category='Memory/Performance') - sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance') - xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance') - tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance') - + # PATHS root : Path = Field(default=None, description='InvokeAI runtime root directory', category='Paths') autoimport_dir : Path = Field(default='autoimport', description='Path to a directory of models files to be imported on startup.', category='Paths') lora_dir : Path = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths') @@ -409,16 +224,41 @@ class InvokeAIAppConfig(InvokeAISettings): legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths') db_dir : Path = Field(default='databases', description='Path to InvokeAI databases directory', category='Paths') outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths') - from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths') use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths') - ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', category='Features') + from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths') + # LOGGING log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=", "syslog=path|address:host:port", "http="', category="Logging") # note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging") log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging") version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other") + + # CACHE + ram : Union[float, Literal["auto"]] = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number or 'auto')", category="Model Cache", ) + vram : Union[float, Literal["auto"]] = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number or 'auto')", category="Model Cache", ) + lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", ) + + # DEVICE + device : Literal[tuple(["auto", "cpu", "cuda", "cuda:1", "mps"])] = Field(default="auto", description="Generation device", category="Device", ) + precision: Literal[tuple(["auto", "float16", "float32", "autocast"])] = Field(default="auto", description="Floating point precision", category="Device", ) + + # GENERATION + sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category="Generation", ) + attention_type : Literal[tuple(["auto", "normal", "xformers", "sliced", "torch-sdp"])] = Field(default="auto", description="Attention type", category="Generation", ) + attention_slice_size: Literal[tuple(["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8])] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", ) + force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",) + + # DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES + always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance') + free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", category='Memory/Performance') + max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance') + max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance') + xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance') + tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance') + + # See InvokeAIAppConfig subclass below for CACHE and DEVICE categories # fmt: on class Config: @@ -541,11 +381,6 @@ class InvokeAIAppConfig(InvokeAISettings): """Return true if precision set to float32""" return self.precision == "float32" - @property - def disable_xformers(self) -> bool: - """Return true if xformers_enabled is false""" - return not self.xformers_enabled - @property def try_patchmatch(self) -> bool: """Return true if patchmatch true""" @@ -561,6 +396,27 @@ class InvokeAIAppConfig(InvokeAISettings): """invisible watermark node is always active and disabled from Web UIe""" return True + @property + def ram_cache_size(self) -> float: + return self.max_cache_size or self.ram + + @property + def vram_cache_size(self) -> float: + return self.max_vram_cache_size or self.vram + + @property + def use_cpu(self) -> bool: + return self.always_use_cpu or self.device == "cpu" + + @property + def disable_xformers(self) -> bool: + """ + Return true if enable_xformers is false (reversed logic) + and attention type is not set to xformers. + """ + disabled_in_config = not self.xformers_enabled + return disabled_in_config and self.attention_type != "xformers" + @staticmethod def find_root() -> Path: """ @@ -570,19 +426,19 @@ class InvokeAIAppConfig(InvokeAISettings): return _find_root() -class PagingArgumentParser(argparse.ArgumentParser): - """ - A custom ArgumentParser that uses pydoc to page its output. - It also supports reading defaults from an init file. - """ - - def print_help(self, file=None): - text = self.format_help() - pydoc.pager(text) - - def get_invokeai_config(**kwargs) -> InvokeAIAppConfig: """ Legacy function which returns InvokeAIAppConfig.get_config() """ return InvokeAIAppConfig.get_config(**kwargs) + + +def _find_root() -> Path: + venv = Path(os.environ.get("VIRTUAL_ENV") or ".") + if os.environ.get("INVOKEAI_ROOT"): + root = Path(os.environ["INVOKEAI_ROOT"]) + elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]): + root = (venv.parent).resolve() + else: + root = Path("~/invokeai").expanduser().resolve() + return root diff --git a/invokeai/app/services/default_graphs.py b/invokeai/app/services/default_graphs.py index 7135e031b0..5e1a594b91 100644 --- a/invokeai/app/services/default_graphs.py +++ b/invokeai/app/services/default_graphs.py @@ -17,9 +17,9 @@ def create_text_to_image() -> LibraryGraph: description="Converts text to an image", graph=Graph( nodes={ - "width": IntegerInvocation(id="width", a=512), - "height": IntegerInvocation(id="height", a=512), - "seed": IntegerInvocation(id="seed", a=-1), + "width": IntegerInvocation(id="width", value=512), + "height": IntegerInvocation(id="height", value=512), + "seed": IntegerInvocation(id="seed", value=-1), "3": NoiseInvocation(id="3"), "4": CompelInvocation(id="4"), "5": CompelInvocation(id="5"), @@ -29,15 +29,15 @@ def create_text_to_image() -> LibraryGraph: }, edges=[ Edge( - source=EdgeConnection(node_id="width", field="a"), + source=EdgeConnection(node_id="width", field="value"), destination=EdgeConnection(node_id="3", field="width"), ), Edge( - source=EdgeConnection(node_id="height", field="a"), + source=EdgeConnection(node_id="height", field="value"), destination=EdgeConnection(node_id="3", field="height"), ), Edge( - source=EdgeConnection(node_id="seed", field="a"), + source=EdgeConnection(node_id="seed", field="value"), destination=EdgeConnection(node_id="3", field="seed"), ), Edge( @@ -65,9 +65,9 @@ def create_text_to_image() -> LibraryGraph: exposed_inputs=[ ExposedNodeInput(node_path="4", field="prompt", alias="positive_prompt"), ExposedNodeInput(node_path="5", field="prompt", alias="negative_prompt"), - ExposedNodeInput(node_path="width", field="a", alias="width"), - ExposedNodeInput(node_path="height", field="a", alias="height"), - ExposedNodeInput(node_path="seed", field="a", alias="seed"), + ExposedNodeInput(node_path="width", field="value", alias="width"), + ExposedNodeInput(node_path="height", field="value", alias="height"), + ExposedNodeInput(node_path="seed", field="value", alias="seed"), ], exposed_outputs=[ExposedNodeOutput(node_path="8", field="image", alias="image")], ) diff --git a/invokeai/app/services/invocation_stats.py b/invokeai/app/services/invocation_stats.py index e8557c40f7..b42d128b51 100644 --- a/invokeai/app/services/invocation_stats.py +++ b/invokeai/app/services/invocation_stats.py @@ -49,9 +49,36 @@ from invokeai.backend.model_management.model_cache import CacheStats GIG = 1073741824 +@dataclass +class NodeStats: + """Class for tracking execution stats of an invocation node""" + + calls: int = 0 + time_used: float = 0.0 # seconds + max_vram: float = 0.0 # GB + cache_hits: int = 0 + cache_misses: int = 0 + cache_high_watermark: int = 0 + + +@dataclass +class NodeLog: + """Class for tracking node usage""" + + # {node_type => NodeStats} + nodes: Dict[str, NodeStats] = field(default_factory=dict) + + class InvocationStatsServiceBase(ABC): "Abstract base class for recording node memory/time performance statistics" + graph_execution_manager: ItemStorageABC["GraphExecutionState"] + # {graph_id => NodeLog} + _stats: Dict[str, NodeLog] + _cache_stats: Dict[str, CacheStats] + ram_used: float + ram_changed: float + @abstractmethod def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]): """ @@ -94,8 +121,6 @@ class InvocationStatsServiceBase(ABC): invocation_type: str, time_used: float, vram_used: float, - ram_used: float, - ram_changed: float, ): """ Add timing information on execution of a node. Usually @@ -104,8 +129,6 @@ class InvocationStatsServiceBase(ABC): :param invocation_type: String literal type of the node :param time_used: Time used by node's exection (sec) :param vram_used: Maximum VRAM used during exection (GB) - :param ram_used: Current RAM available (GB) - :param ram_changed: Change in RAM usage over course of the run (GB) """ pass @@ -116,25 +139,19 @@ class InvocationStatsServiceBase(ABC): """ pass + @abstractmethod + def update_mem_stats( + self, + ram_used: float, + ram_changed: float, + ): + """ + Update the collector with RAM memory usage info. -@dataclass -class NodeStats: - """Class for tracking execution stats of an invocation node""" - - calls: int = 0 - time_used: float = 0.0 # seconds - max_vram: float = 0.0 # GB - cache_hits: int = 0 - cache_misses: int = 0 - cache_high_watermark: int = 0 - - -@dataclass -class NodeLog: - """Class for tracking node usage""" - - # {node_type => NodeStats} - nodes: Dict[str, NodeStats] = field(default_factory=dict) + :param ram_used: How much RAM is currently in use. + :param ram_changed: How much RAM changed since last generation. + """ + pass class InvocationStatsService(InvocationStatsServiceBase): @@ -152,12 +169,12 @@ class InvocationStatsService(InvocationStatsServiceBase): class StatsContext: """Context manager for collecting statistics.""" - invocation: BaseInvocation = None - collector: "InvocationStatsServiceBase" = None - graph_id: str = None - start_time: int = 0 - ram_used: int = 0 - model_manager: ModelManagerService = None + invocation: BaseInvocation + collector: "InvocationStatsServiceBase" + graph_id: str + start_time: float + ram_used: int + model_manager: ModelManagerService def __init__( self, @@ -170,7 +187,7 @@ class InvocationStatsService(InvocationStatsServiceBase): self.invocation = invocation self.collector = collector self.graph_id = graph_id - self.start_time = 0 + self.start_time = 0.0 self.ram_used = 0 self.model_manager = model_manager @@ -191,7 +208,7 @@ class InvocationStatsService(InvocationStatsServiceBase): ) self.collector.update_invocation_stats( graph_id=self.graph_id, - invocation_type=self.invocation.type, + invocation_type=self.invocation.type, # type: ignore - `type` is not on the `BaseInvocation` model, but *is* on all invocations time_used=time.time() - self.start_time, vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0, ) @@ -202,11 +219,6 @@ class InvocationStatsService(InvocationStatsServiceBase): graph_execution_state_id: str, model_manager: ModelManagerService, ) -> StatsContext: - """ - Return a context object that will capture the statistics. - :param invocation: BaseInvocation object from the current graph. - :param graph_execution_state: GraphExecutionState object from the current session. - """ if not self._stats.get(graph_execution_state_id): # first time we're seeing this self._stats[graph_execution_state_id] = NodeLog() self._cache_stats[graph_execution_state_id] = CacheStats() @@ -217,7 +229,6 @@ class InvocationStatsService(InvocationStatsServiceBase): self._stats = {} def reset_stats(self, graph_execution_id: str): - """Zero the statistics for the indicated graph.""" try: self._stats.pop(graph_execution_id) except KeyError: @@ -228,12 +239,6 @@ class InvocationStatsService(InvocationStatsServiceBase): ram_used: float, ram_changed: float, ): - """ - Update the collector with RAM memory usage info. - - :param ram_used: How much RAM is currently in use. - :param ram_changed: How much RAM changed since last generation. - """ self.ram_used = ram_used self.ram_changed = ram_changed @@ -244,16 +249,6 @@ class InvocationStatsService(InvocationStatsServiceBase): time_used: float, vram_used: float, ): - """ - Add timing information on execution of a node. Usually - used internally. - :param graph_id: ID of the graph that is currently executing - :param invocation_type: String literal type of the node - :param time_used: Time used by node's exection (sec) - :param vram_used: Maximum VRAM used during exection (GB) - :param ram_used: Current RAM available (GB) - :param ram_changed: Change in RAM usage over course of the run (GB) - """ if not self._stats[graph_id].nodes.get(invocation_type): self._stats[graph_id].nodes[invocation_type] = NodeStats() stats = self._stats[graph_id].nodes[invocation_type] @@ -262,14 +257,15 @@ class InvocationStatsService(InvocationStatsServiceBase): stats.max_vram = max(stats.max_vram, vram_used) def log_stats(self): - """ - Send the statistics to the system logger at the info level. - Stats will only be printed when the execution of the graph - is complete. - """ completed = set() + errored = set() for graph_id, node_log in self._stats.items(): - current_graph_state = self.graph_execution_manager.get(graph_id) + try: + current_graph_state = self.graph_execution_manager.get(graph_id) + except Exception: + errored.add(graph_id) + continue + if not current_graph_state.is_complete(): continue @@ -302,3 +298,7 @@ class InvocationStatsService(InvocationStatsServiceBase): for graph_id in completed: del self._stats[graph_id] del self._cache_stats[graph_id] + + for graph_id in errored: + del self._stats[graph_id] + del self._cache_stats[graph_id] diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 675bc71257..11ebab7938 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -330,8 +330,8 @@ class ModelManagerService(ModelManagerServiceBase): # configuration value. If present, then the # cache size is set to 2.5 GB times # the number of max_loaded_models. Otherwise - # use new `max_cache_size` config setting - max_cache_size = config.max_cache_size if hasattr(config, "max_cache_size") else config.max_loaded_models * 2.5 + # use new `ram_cache_size` config setting + max_cache_size = config.ram_cache_size logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB") diff --git a/invokeai/backend/image_util/lama.py b/invokeai/backend/image_util/lama.py new file mode 100644 index 0000000000..2ea22b6fa3 --- /dev/null +++ b/invokeai/backend/image_util/lama.py @@ -0,0 +1,56 @@ +import gc +from typing import Any + +import numpy as np +import torch +from PIL import Image + +from invokeai.app.services.config import get_invokeai_config +from invokeai.backend.util.devices import choose_torch_device + + +def norm_img(np_img): + if len(np_img.shape) == 2: + np_img = np_img[:, :, np.newaxis] + np_img = np.transpose(np_img, (2, 0, 1)) + np_img = np_img.astype("float32") / 255 + return np_img + + +def load_jit_model(url_or_path, device): + model_path = url_or_path + print(f"Loading model from: {model_path}") + model = torch.jit.load(model_path, map_location="cpu").to(device) + model.eval() + return model + + +class LaMA: + def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any: + device = choose_torch_device() + model_location = get_invokeai_config().models_path / "core/misc/lama/lama.pt" + model = load_jit_model(model_location, device) + + image = np.asarray(input_image.convert("RGB")) + image = norm_img(image) + + mask = input_image.split()[-1] + mask = np.asarray(mask) + mask = np.invert(mask) + mask = norm_img(mask) + + mask = (mask > 0) * 1 + image = torch.from_numpy(image).unsqueeze(0).to(device) + mask = torch.from_numpy(mask).unsqueeze(0).to(device) + + with torch.inference_mode(): + infilled_image = model(image, mask) + + infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy() + infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8") + infilled_image = Image.fromarray(infilled_image) + + del model + gc.collect() + + return infilled_image diff --git a/invokeai/backend/install/invokeai_configure.py b/invokeai/backend/install/invokeai_configure.py index d7ecb41e9b..7925066562 100755 --- a/invokeai/backend/install/invokeai_configure.py +++ b/invokeai/backend/install/invokeai_configure.py @@ -21,6 +21,7 @@ from argparse import Namespace from enum import Enum from pathlib import Path from shutil import get_terminal_size +from typing import get_type_hints, get_args, Any from urllib import request import npyscreen @@ -50,6 +51,7 @@ from invokeai.frontend.install.model_install import addModelsForm, process_and_e # TO DO - Move all the frontend code into invokeai.frontend.install from invokeai.frontend.install.widgets import ( SingleSelectColumns, + MultiSelectColumns, CenteredButtonPress, FileBox, set_min_terminal_size, @@ -71,6 +73,10 @@ warnings.filterwarnings("ignore") transformers.logging.set_verbosity_error() +def get_literal_fields(field) -> list[Any]: + return get_args(get_type_hints(InvokeAIAppConfig).get(field)) + + # --------------------------globals----------------------- config = InvokeAIAppConfig.get_config() @@ -80,7 +86,11 @@ Model_dir = "models" Default_config_file = config.model_conf_path SD_Configs = config.legacy_conf_path -PRECISION_CHOICES = ["auto", "float16", "float32"] +PRECISION_CHOICES = get_literal_fields("precision") +DEVICE_CHOICES = get_literal_fields("device") +ATTENTION_CHOICES = get_literal_fields("attention_type") +ATTENTION_SLICE_CHOICES = get_literal_fields("attention_slice_size") +GENERATION_OPT_CHOICES = ["sequential_guidance", "force_tiled_decode", "lazy_offload"] GB = 1073741824 # GB in bytes HAS_CUDA = torch.cuda.is_available() _, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0, 0) @@ -311,6 +321,7 @@ class editOptsForm(CyclingForm, npyscreen.FormMultiPage): Use ctrl-N and ctrl-P to move to the ext and

revious fields. Use cursor arrows to make a checkbox selection, and space to toggle. """ + self.nextrely -= 1 for i in textwrap.wrap(label, width=window_width - 6): self.add_widget_intelligent( npyscreen.FixedText, @@ -337,76 +348,129 @@ Use cursor arrows to make a checkbox selection, and space to toggle. use_two_lines=False, scroll_exit=True, ) - self.nextrely += 1 - self.add_widget_intelligent( - npyscreen.TitleFixedText, - name="GPU Management", - begin_entry_at=0, - editable=False, - color="CONTROL", - scroll_exit=True, - ) - self.nextrely -= 1 - self.free_gpu_mem = self.add_widget_intelligent( - npyscreen.Checkbox, - name="Free GPU memory after each generation", - value=old_opts.free_gpu_mem, - max_width=45, - relx=5, - scroll_exit=True, - ) - self.nextrely -= 1 - self.xformers_enabled = self.add_widget_intelligent( - npyscreen.Checkbox, - name="Enable xformers support", - value=old_opts.xformers_enabled, - max_width=30, - relx=50, - scroll_exit=True, - ) - self.nextrely -= 1 - self.always_use_cpu = self.add_widget_intelligent( - npyscreen.Checkbox, - name="Force CPU to be used on GPU systems", - value=old_opts.always_use_cpu, - relx=80, - scroll_exit=True, - ) + + # old settings for defaults precision = old_opts.precision or ("float32" if program_opts.full_precision else "auto") + device = old_opts.device + attention_type = old_opts.attention_type + attention_slice_size = old_opts.attention_slice_size + self.nextrely += 1 self.add_widget_intelligent( npyscreen.TitleFixedText, - name="Floating Point Precision", + name="Image Generation Options:", + editable=False, + color="CONTROL", + scroll_exit=True, + ) + self.nextrely -= 2 + self.generation_options = self.add_widget_intelligent( + MultiSelectColumns, + columns=3, + values=GENERATION_OPT_CHOICES, + value=[GENERATION_OPT_CHOICES.index(x) for x in GENERATION_OPT_CHOICES if getattr(old_opts, x)], + relx=30, + max_height=2, + max_width=80, + scroll_exit=True, + ) + + self.add_widget_intelligent( + npyscreen.TitleFixedText, + name="Floating Point Precision:", begin_entry_at=0, editable=False, color="CONTROL", scroll_exit=True, ) - self.nextrely -= 1 + self.nextrely -= 2 self.precision = self.add_widget_intelligent( SingleSelectColumns, - columns=3, + columns=len(PRECISION_CHOICES), name="Precision", values=PRECISION_CHOICES, value=PRECISION_CHOICES.index(precision), begin_entry_at=3, max_height=2, + relx=30, + max_width=56, + scroll_exit=True, + ) + self.add_widget_intelligent( + npyscreen.TitleFixedText, + name="Generation Device:", + begin_entry_at=0, + editable=False, + color="CONTROL", + scroll_exit=True, + ) + self.nextrely -= 2 + self.device = self.add_widget_intelligent( + SingleSelectColumns, + columns=len(DEVICE_CHOICES), + values=DEVICE_CHOICES, + value=DEVICE_CHOICES.index(device), + begin_entry_at=3, + relx=30, + max_height=2, + max_width=60, + scroll_exit=True, + ) + self.add_widget_intelligent( + npyscreen.TitleFixedText, + name="Attention Type:", + begin_entry_at=0, + editable=False, + color="CONTROL", + scroll_exit=True, + ) + self.nextrely -= 2 + self.attention_type = self.add_widget_intelligent( + SingleSelectColumns, + columns=len(ATTENTION_CHOICES), + values=ATTENTION_CHOICES, + value=ATTENTION_CHOICES.index(attention_type), + begin_entry_at=3, + max_height=2, + relx=30, max_width=80, scroll_exit=True, ) - self.nextrely += 1 + self.attention_type.on_changed = self.show_hide_slice_sizes + self.attention_slice_label = self.add_widget_intelligent( + npyscreen.TitleFixedText, + name="Attention Slice Size:", + relx=5, + editable=False, + hidden=attention_type != "sliced", + color="CONTROL", + scroll_exit=True, + ) + self.nextrely -= 2 + self.attention_slice_size = self.add_widget_intelligent( + SingleSelectColumns, + columns=len(ATTENTION_SLICE_CHOICES), + values=ATTENTION_SLICE_CHOICES, + value=ATTENTION_SLICE_CHOICES.index(attention_slice_size), + relx=30, + hidden=attention_type != "sliced", + max_height=2, + max_width=110, + scroll_exit=True, + ) + self.add_widget_intelligent( npyscreen.TitleFixedText, - name="RAM cache size (GB). Make this at least large enough to hold a single full model.", + name="Model RAM cache size (GB). Make this at least large enough to hold a single full model.", begin_entry_at=0, editable=False, color="CONTROL", scroll_exit=True, ) self.nextrely -= 1 - self.max_cache_size = self.add_widget_intelligent( + self.ram = self.add_widget_intelligent( npyscreen.Slider, - value=clip(old_opts.max_cache_size, range=(3.0, MAX_RAM), step=0.5), + value=clip(old_opts.ram_cache_size, range=(3.0, MAX_RAM), step=0.5), out_of=round(MAX_RAM), lowest=0.0, step=0.5, @@ -417,16 +481,16 @@ Use cursor arrows to make a checkbox selection, and space to toggle. self.nextrely += 1 self.add_widget_intelligent( npyscreen.TitleFixedText, - name="VRAM cache size (GB). Reserving a small amount of VRAM will modestly speed up the start of image generation.", + name="Model VRAM cache size (GB). Reserving a small amount of VRAM will modestly speed up the start of image generation.", begin_entry_at=0, editable=False, color="CONTROL", scroll_exit=True, ) self.nextrely -= 1 - self.max_vram_cache_size = self.add_widget_intelligent( + self.vram = self.add_widget_intelligent( npyscreen.Slider, - value=clip(old_opts.max_vram_cache_size, range=(0, MAX_VRAM), step=0.25), + value=clip(old_opts.vram_cache_size, range=(0, MAX_VRAM), step=0.25), out_of=round(MAX_VRAM * 2) / 2, lowest=0.0, relx=8, @@ -434,7 +498,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle. scroll_exit=True, ) else: - self.max_vram_cache_size = DummyWidgetValue.zero + self.vram_cache_size = DummyWidgetValue.zero self.nextrely += 1 self.outdir = self.add_widget_intelligent( FileBox, @@ -490,6 +554,11 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS when_pressed_function=self.on_ok, ) + def show_hide_slice_sizes(self, value): + show = ATTENTION_CHOICES[value[0]] == "sliced" + self.attention_slice_label.hidden = not show + self.attention_slice_size.hidden = not show + def on_ok(self): options = self.marshall_arguments() if self.validate_field_values(options): @@ -523,12 +592,9 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS new_opts = Namespace() for attr in [ + "ram", + "vram", "outdir", - "free_gpu_mem", - "max_cache_size", - "max_vram_cache_size", - "xformers_enabled", - "always_use_cpu", ]: setattr(new_opts, attr, getattr(self, attr).value) @@ -541,6 +607,12 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS new_opts.hf_token = self.hf_token.value new_opts.license_acceptance = self.license_acceptance.value new_opts.precision = PRECISION_CHOICES[self.precision.value[0]] + new_opts.device = DEVICE_CHOICES[self.device.value[0]] + new_opts.attention_type = ATTENTION_CHOICES[self.attention_type.value[0]] + new_opts.attention_slice_size = ATTENTION_SLICE_CHOICES[self.attention_slice_size.value[0]] + generation_options = [GENERATION_OPT_CHOICES[x] for x in self.generation_options.value] + for v in GENERATION_OPT_CHOICES: + setattr(new_opts, v, v in generation_options) return new_opts diff --git a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py index 5fd3669911..8118e28abb 100644 --- a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py +++ b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py @@ -20,11 +20,36 @@ import re from contextlib import nullcontext from io import BytesIO -from typing import Optional, Union from pathlib import Path +from typing import Optional, Union import requests import torch +from diffusers.models import ( + AutoencoderKL, + ControlNetModel, + PriorTransformer, + UNet2DConditionModel, +) +from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel +from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer +from diffusers.schedulers import ( + DDIMScheduler, + DDPMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + UnCLIPScheduler, +) +from diffusers.utils import is_accelerate_available, is_omegaconf_available +from diffusers.utils.import_utils import BACKENDS_MAPPING +from picklescan.scanner import scan_file_path from transformers import ( AutoFeatureExtractor, BertTokenizerFast, @@ -37,35 +62,8 @@ from transformers import ( CLIPVisionModelWithProjection, ) -from diffusers.models import ( - AutoencoderKL, - ControlNetModel, - PriorTransformer, - UNet2DConditionModel, -) -from diffusers.schedulers import ( - DDIMScheduler, - DDPMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - HeunDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, - UnCLIPScheduler, -) -from diffusers.utils import is_accelerate_available, is_omegaconf_available, is_safetensors_available -from diffusers.utils.import_utils import BACKENDS_MAPPING -from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel -from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder -from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer - -from invokeai.backend.util.logging import InvokeAILogger from invokeai.app.services.config import InvokeAIAppConfig - -from picklescan.scanner import scan_file_path +from invokeai.backend.util.logging import InvokeAILogger from .models import BaseModelType, ModelVariantType try: @@ -1221,9 +1219,6 @@ def download_from_original_stable_diffusion_ckpt( raise ValueError(BACKENDS_MAPPING["omegaconf"][1]) if from_safetensors: - if not is_safetensors_available(): - raise ValueError(BACKENDS_MAPPING["safetensors"][1]) - from safetensors.torch import load_file as safe_load checkpoint = safe_load(checkpoint_path, device="cpu") @@ -1662,9 +1657,6 @@ def download_controlnet_from_original_ckpt( from omegaconf import OmegaConf if from_safetensors: - if not is_safetensors_available(): - raise ValueError(BACKENDS_MAPPING["safetensors"][1]) - from safetensors import safe_open checkpoint = {} @@ -1741,7 +1733,7 @@ def convert_ckpt_to_diffusers( pipe.save_pretrained( dump_path, - safe_serialization=use_safetensors and is_safetensors_available(), + safe_serialization=use_safetensors, ) @@ -1757,7 +1749,4 @@ def convert_controlnet_to_diffusers( """ pipe = download_controlnet_from_original_ckpt(checkpoint_path, **kwargs) - pipe.save_pretrained( - dump_path, - safe_serialization=is_safetensors_available(), - ) + pipe.save_pretrained(dump_path, safe_serialization=True) diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 1b10554e69..d87bc03fb7 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -341,7 +341,8 @@ class ModelManager(object): self.logger = logger self.cache = ModelCache( max_cache_size=max_cache_size, - max_vram_cache_size=self.app_config.max_vram_cache_size, + max_vram_cache_size=self.app_config.vram_cache_size, + lazy_offloading=self.app_config.lazy_offload, execution_device=device_type, precision=precision, sequential_offload=sequential_offload, diff --git a/invokeai/backend/model_management/models/vae.py b/invokeai/backend/model_management/models/vae.py index cf7622a9aa..f5dc11b27b 100644 --- a/invokeai/backend/model_management/models/vae.py +++ b/invokeai/backend/model_management/models/vae.py @@ -5,7 +5,6 @@ from typing import Optional import safetensors import torch -from diffusers.utils import is_safetensors_available from omegaconf import OmegaConf from invokeai.app.services.config import InvokeAIAppConfig @@ -175,5 +174,5 @@ def _convert_vae_ckpt_and_cache( vae_config=config, image_size=image_size, ) - vae_model.save_pretrained(output_path, safe_serialization=is_safetensors_available()) + vae_model.save_pretrained(output_path, safe_serialization=True) return output_path diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 0180830b76..63b0c78b51 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -33,7 +33,7 @@ from .diffusion import ( PostprocessingSettings, BasicConditioningInfo, ) -from ..util import normalize_device +from ..util import normalize_device, auto_detect_slice_size @dataclass @@ -291,6 +291,24 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): if xformers is available, use it, otherwise use sliced attention. """ config = InvokeAIAppConfig.get_config() + if config.attention_type == "xformers": + self.enable_xformers_memory_efficient_attention() + return + elif config.attention_type == "sliced": + slice_size = config.attention_slice_size + if slice_size == "auto": + slice_size = auto_detect_slice_size(latents) + elif slice_size == "balanced": + slice_size = "auto" + self.enable_attention_slicing(slice_size=slice_size) + return + elif config.attention_type == "normal": + self.disable_attention_slicing() + return + elif config.attention_type == "torch-sdp": + raise Exception("torch-sdp attention slicing not yet implemented") + + # the remainder if this code is called when attention_type=='auto' if self.unet.device.type == "cuda": if is_xformers_available() and not config.disable_xformers: self.enable_xformers_memory_efficient_attention() diff --git a/invokeai/backend/util/__init__.py b/invokeai/backend/util/__init__.py index 30bb0efc15..b4e1c6e3a3 100644 --- a/invokeai/backend/util/__init__.py +++ b/invokeai/backend/util/__init__.py @@ -11,4 +11,11 @@ from .devices import ( # noqa: F401 torch_dtype, ) from .log import write_log # noqa: F401 -from .util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name, Chdir # noqa: F401 +from .util import ( # noqa: F401 + ask_user, + download_with_resume, + instantiate_from_config, + url_attachment_name, + Chdir, +) +from .attention import auto_detect_slice_size # noqa: F401 diff --git a/invokeai/backend/util/attention.py b/invokeai/backend/util/attention.py new file mode 100644 index 0000000000..a821464394 --- /dev/null +++ b/invokeai/backend/util/attention.py @@ -0,0 +1,32 @@ +# Copyright (c) 2023 Lincoln Stein and the InvokeAI Team +""" +Utility routine used for autodetection of optimal slice size +for attention mechanism. +""" +import torch +import psutil + + +def auto_detect_slice_size(latents: torch.Tensor) -> str: + bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4 + max_size_required_for_baddbmm = ( + 16 + * latents.size(dim=2) + * latents.size(dim=3) + * latents.size(dim=2) + * latents.size(dim=3) + * bytes_per_element_needed_for_baddbmm_duplication + ) + if latents.device.type in {"cpu", "mps"}: + mem_free = psutil.virtual_memory().free + elif latents.device.type == "cuda": + mem_free, _ = torch.cuda.mem_get_info(latents.device) + else: + raise ValueError(f"unrecognized device {latents.device}") + + if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0): + return "max" + elif torch.backends.mps.is_available(): + return "max" + else: + return "balanced" diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index 1827f295e4..bdaf3244f3 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -17,13 +17,17 @@ config = InvokeAIAppConfig.get_config() def choose_torch_device() -> torch.device: """Convenience routine for guessing which GPU device to run model on""" - if config.always_use_cpu: + if config.use_cpu: # legacy setting - force CPU return CPU_DEVICE - if torch.cuda.is_available(): - return torch.device("cuda") - if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): - return torch.device("mps") - return CPU_DEVICE + elif config.device == "auto": + if torch.cuda.is_available(): + return torch.device("cuda") + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + return torch.device("mps") + else: + return CPU_DEVICE + else: + return torch.device(config.device) def choose_precision(device: torch.device) -> str: diff --git a/invokeai/frontend/install/widgets.py b/invokeai/frontend/install/widgets.py index 79b6280990..f7d1d044c8 100644 --- a/invokeai/frontend/install/widgets.py +++ b/invokeai/frontend/install/widgets.py @@ -17,8 +17,8 @@ from shutil import get_terminal_size from curses import BUTTON2_CLICKED, BUTTON3_CLICKED # minimum size for UIs -MIN_COLS = 130 -MIN_LINES = 38 +MIN_COLS = 150 +MIN_LINES = 40 class WindowTooSmallException(Exception): @@ -277,6 +277,9 @@ class SingleSelectColumns(SelectColumnBase, SingleSelectWithChanged): def h_cursor_line_right(self, ch): self.h_exit_down("bye bye") + def h_cursor_line_left(self, ch): + self.h_exit_up("bye bye") + class TextBoxInner(npyscreen.MultiLineEdit): def __init__(self, *args, **kwargs): @@ -324,55 +327,6 @@ class TextBoxInner(npyscreen.MultiLineEdit): if bstate & (BUTTON2_CLICKED | BUTTON3_CLICKED): self.h_paste() - # def update(self, clear=True): - # if clear: - # self.clear() - - # HEIGHT = self.height - # WIDTH = self.width - # # draw box. - # self.parent.curses_pad.hline(self.rely, self.relx, curses.ACS_HLINE, WIDTH) - # self.parent.curses_pad.hline( - # self.rely + HEIGHT, self.relx, curses.ACS_HLINE, WIDTH - # ) - # self.parent.curses_pad.vline( - # self.rely, self.relx, curses.ACS_VLINE, self.height - # ) - # self.parent.curses_pad.vline( - # self.rely, self.relx + WIDTH, curses.ACS_VLINE, HEIGHT - # ) - - # # draw corners - # self.parent.curses_pad.addch( - # self.rely, - # self.relx, - # curses.ACS_ULCORNER, - # ) - # self.parent.curses_pad.addch( - # self.rely, - # self.relx + WIDTH, - # curses.ACS_URCORNER, - # ) - # self.parent.curses_pad.addch( - # self.rely + HEIGHT, - # self.relx, - # curses.ACS_LLCORNER, - # ) - # self.parent.curses_pad.addch( - # self.rely + HEIGHT, - # self.relx + WIDTH, - # curses.ACS_LRCORNER, - # ) - - # # fool our superclass into thinking drawing area is smaller - this is really hacky but it seems to work - # (relx, rely, height, width) = (self.relx, self.rely, self.height, self.width) - # self.relx += 1 - # self.rely += 1 - # self.height -= 1 - # self.width -= 1 - # super().update(clear=False) - # (self.relx, self.rely, self.height, self.width) = (relx, rely, height, width) - class TextBox(npyscreen.BoxTitle): _contained_widget = TextBoxInner diff --git a/invokeai/frontend/web/.eslintrc.js b/invokeai/frontend/web/.eslintrc.js index c48e08d45e..c2b1433a9a 100644 --- a/invokeai/frontend/web/.eslintrc.js +++ b/invokeai/frontend/web/.eslintrc.js @@ -9,8 +9,8 @@ module.exports = { 'plugin:@typescript-eslint/recommended', 'plugin:react/recommended', 'plugin:react-hooks/recommended', - 'plugin:prettier/recommended', 'plugin:react/jsx-runtime', + 'prettier', ], parser: '@typescript-eslint/parser', parserOptions: { @@ -23,6 +23,11 @@ module.exports = { plugins: ['react', '@typescript-eslint', 'eslint-plugin-react-hooks'], root: true, rules: { + curly: 'error', + 'react/jsx-curly-brace-presence': [ + 'error', + { props: 'never', children: 'never' }, + ], 'react-hooks/exhaustive-deps': 'error', 'no-var': 'error', 'brace-style': 'error', @@ -34,7 +39,6 @@ module.exports = { 'warn', { varsIgnorePattern: '^_', argsIgnorePattern: '^_' }, ], - 'prettier/prettier': ['error', { endOfLine: 'auto' }], '@typescript-eslint/ban-ts-comment': 'warn', '@typescript-eslint/no-explicit-any': 'warn', '@typescript-eslint/no-empty-interface': [ diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json index 6c9db74bbc..e3f6dc48d7 100644 --- a/invokeai/frontend/web/package.json +++ b/invokeai/frontend/web/package.json @@ -29,12 +29,13 @@ "lint:eslint": "eslint --max-warnings=0 .", "lint:prettier": "prettier --check .", "lint:tsc": "tsc --noEmit", - "lint": "yarn run lint:eslint && yarn run lint:prettier && yarn run lint:tsc && yarn run lint:madge", + "lint": "concurrently -g -n eslint,prettier,tsc,madge -c cyan,green,magenta,yellow \"yarn run lint:eslint\" \"yarn run lint:prettier\" \"yarn run lint:tsc\" \"yarn run lint:madge\"", "fix": "eslint --fix . && prettier --loglevel warn --write . && tsc --noEmit", "lint-staged": "lint-staged", "postinstall": "patch-package && yarn run theme", "theme": "chakra-cli tokens src/theme/theme.ts", - "theme:watch": "chakra-cli tokens src/theme/theme.ts --watch" + "theme:watch": "chakra-cli tokens src/theme/theme.ts --watch", + "up": "yarn upgrade-interactive --latest" }, "madge": { "detectiveOptions": { @@ -54,7 +55,7 @@ }, "dependencies": { "@chakra-ui/anatomy": "^2.2.0", - "@chakra-ui/icons": "^2.0.19", + "@chakra-ui/icons": "^2.1.0", "@chakra-ui/react": "^2.8.0", "@chakra-ui/styled-system": "^2.9.1", "@chakra-ui/theme-tools": "^2.1.0", @@ -65,55 +66,55 @@ "@emotion/react": "^11.11.1", "@emotion/styled": "^11.11.0", "@floating-ui/react-dom": "^2.0.1", - "@fontsource-variable/inter": "^5.0.3", - "@fontsource/inter": "^5.0.3", - "@mantine/core": "^6.0.14", - "@mantine/form": "^6.0.15", - "@mantine/hooks": "^6.0.14", + "@fontsource-variable/inter": "^5.0.8", + "@fontsource/inter": "^5.0.8", + "@mantine/core": "^6.0.19", + "@mantine/form": "^6.0.19", + "@mantine/hooks": "^6.0.19", "@nanostores/react": "^0.7.1", "@reduxjs/toolkit": "^1.9.5", "@roarr/browser-log-writer": "^1.1.5", - "chakra-ui-contextmenu": "^1.0.5", "dateformat": "^5.0.3", - "downshift": "^7.6.0", - "formik": "^2.4.2", - "framer-motion": "^10.12.17", + "formik": "^2.4.3", + "framer-motion": "^10.16.1", "fuse.js": "^6.6.2", - "i18next": "^23.2.3", + "i18next": "^23.4.4", "i18next-browser-languagedetector": "^7.0.2", "i18next-http-backend": "^2.2.1", "konva": "^9.2.0", "lodash-es": "^4.17.21", "nanostores": "^0.9.2", - "openapi-fetch": "^0.6.1", + "new-github-issue-url": "^1.0.0", + "openapi-fetch": "^0.7.4", "overlayscrollbars": "^2.2.0", "overlayscrollbars-react": "^0.5.0", - "patch-package": "^7.0.0", + "patch-package": "^8.0.0", "query-string": "^8.1.0", - "re-resizable": "^6.9.9", "react": "^18.2.0", "react-colorful": "^5.6.1", "react-dom": "^18.2.0", "react-dropzone": "^14.2.3", - "react-hotkeys-hook": "4.4.0", - "react-i18next": "^13.0.1", + "react-error-boundary": "^4.0.11", + "react-hotkeys-hook": "4.4.1", + "react-i18next": "^13.1.2", "react-icons": "^4.10.1", "react-konva": "^18.2.10", - "react-redux": "^8.1.1", - "react-resizable-panels": "^0.0.52", + "react-redux": "^8.1.2", + "react-resizable-panels": "^0.0.55", "react-use": "^17.4.0", - "react-virtuoso": "^4.3.11", + "react-virtuoso": "^4.5.0", "react-zoom-pan-pinch": "^3.0.8", - "reactflow": "^11.7.4", + "reactflow": "^11.8.3", "redux-dynamic-middlewares": "^2.2.0", - "redux-remember": "^3.3.1", - "roarr": "^7.15.0", - "serialize-error": "^11.0.0", - "socket.io-client": "^4.7.0", + "redux-remember": "^4.0.1", + "roarr": "^7.15.1", + "serialize-error": "^11.0.1", + "socket.io-client": "^4.7.2", "use-debounce": "^9.0.4", "use-image": "^1.1.1", "uuid": "^9.0.0", - "zod": "^3.21.4" + "zod": "^3.22.2", + "zod-validation-error": "^1.5.0" }, "peerDependencies": { "@chakra-ui/cli": "^2.4.0", @@ -126,38 +127,36 @@ "@chakra-ui/cli": "^2.4.1", "@types/dateformat": "^5.0.0", "@types/lodash-es": "^4.14.194", - "@types/node": "^20.3.1", - "@types/react": "^18.2.14", + "@types/node": "^20.5.1", + "@types/react": "^18.2.20", "@types/react-dom": "^18.2.6", "@types/react-redux": "^7.1.25", "@types/react-transition-group": "^4.4.6", "@types/uuid": "^9.0.2", - "@typescript-eslint/eslint-plugin": "^5.60.0", - "@typescript-eslint/parser": "^5.60.0", + "@typescript-eslint/eslint-plugin": "^6.4.1", + "@typescript-eslint/parser": "^6.4.1", "@vitejs/plugin-react-swc": "^3.3.2", "axios": "^1.4.0", "babel-plugin-transform-imports": "^2.0.0", "concurrently": "^8.2.0", - "eslint": "^8.43.0", - "eslint-config-prettier": "^8.8.0", - "eslint-plugin-prettier": "^4.2.1", - "eslint-plugin-react": "^7.32.2", + "eslint": "^8.47.0", + "eslint-config-prettier": "^9.0.0", + "eslint-plugin-prettier": "^5.0.0", + "eslint-plugin-react": "^7.33.2", "eslint-plugin-react-hooks": "^4.6.0", "form-data": "^4.0.0", "husky": "^8.0.3", - "lint-staged": "^13.2.2", + "lint-staged": "^14.0.1", "madge": "^6.1.0", "openapi-types": "^12.1.3", - "openapi-typescript": "^6.2.8", - "openapi-typescript-codegen": "^0.24.0", + "openapi-typescript": "^6.5.2", "postinstall-postinstall": "^2.1.0", - "prettier": "^2.8.8", + "prettier": "^3.0.2", "rollup-plugin-visualizer": "^5.9.2", - "terser": "^5.18.1", "ts-toolbelt": "^9.6.0", - "vite": "^4.3.9", - "vite-plugin-css-injected-by-js": "^3.1.1", - "vite-plugin-dts": "^2.3.0", + "vite": "^4.4.9", + "vite-plugin-css-injected-by-js": "^3.3.0", + "vite-plugin-dts": "^3.5.2", "vite-plugin-eslint": "^1.8.1", "vite-tsconfig-paths": "^4.2.0", "yarn": "^1.22.19" diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index f41da82e07..e39f438146 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -19,7 +19,7 @@ "toggleAutoscroll": "Toggle autoscroll", "toggleLogViewer": "Toggle Log Viewer", "showGallery": "Show Gallery", - "showOptionsPanel": "Show Options Panel", + "showOptionsPanel": "Show Side Panel", "menu": "Menu" }, "common": { @@ -52,7 +52,7 @@ "img2img": "Image To Image", "unifiedCanvas": "Unified Canvas", "linear": "Linear", - "nodes": "Node Editor", + "nodes": "Workflow Editor", "batch": "Batch Manager", "modelManager": "Model Manager", "postprocessing": "Post Processing", @@ -95,7 +95,6 @@ "statusModelConverted": "Model Converted", "statusMergingModels": "Merging Models", "statusMergedModels": "Models Merged", - "pinOptionsPanel": "Pin Options Panel", "loading": "Loading", "loadingInvokeAI": "Loading Invoke AI", "random": "Random", @@ -116,7 +115,6 @@ "maintainAspectRatio": "Maintain Aspect Ratio", "autoSwitchNewImages": "Auto-Switch to New Images", "singleColumnLayout": "Single Column Layout", - "pinGallery": "Pin Gallery", "allImagesLoaded": "All Images Loaded", "loadMore": "Load More", "noImagesInGallery": "No Images to Display", @@ -133,6 +131,7 @@ "generalHotkeys": "General Hotkeys", "galleryHotkeys": "Gallery Hotkeys", "unifiedCanvasHotkeys": "Unified Canvas Hotkeys", + "nodesHotkeys": "Nodes Hotkeys", "invoke": { "title": "Invoke", "desc": "Generate an image" @@ -332,6 +331,10 @@ "acceptStagingImage": { "title": "Accept Staging Image", "desc": "Accept Current Staging Area Image" + }, + "addNodes": { + "title": "Add Nodes", + "desc": "Opens the add node menu" } }, "modelManager": { @@ -506,12 +509,9 @@ "maskAdjustmentsHeader": "Mask Adjustments", "maskBlur": "Mask Blur", "maskBlurMethod": "Mask Blur Method", - "seamPaintingHeader": "Seam Painting", - "seamSize": "Seam Size", - "seamBlur": "Seam Blur", - "seamSteps": "Seam Steps", - "seamStrength": "Seam Strength", - "seamThreshold": "Seam Threshold", + "coherencePassHeader": "Coherence Pass", + "coherenceSteps": "Coherence Pass Steps", + "coherenceStrength": "Coherence Pass Strength", "seamLowThreshold": "Low", "seamHighThreshold": "High", "scaleBeforeProcessing": "Scale Before Processing", @@ -572,7 +572,7 @@ "resetWebUI": "Reset Web UI", "resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.", "resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.", - "resetComplete": "Web UI has been reset. Refresh the page to reload.", + "resetComplete": "Web UI has been reset.", "consoleLogLevel": "Log Level", "shouldLogToConsole": "Console Logging", "developer": "Developer", @@ -715,11 +715,12 @@ "swapSizes": "Swap Sizes" }, "nodes": { - "reloadSchema": "Reload Schema", - "saveGraph": "Save Graph", - "loadGraph": "Load Graph (saved from Node Editor) (Do not copy-paste metadata)", - "clearGraph": "Clear Graph", - "clearGraphDesc": "Are you sure you want to clear all nodes?", + "reloadNodeTemplates": "Reload Node Templates", + "saveWorkflow": "Save Workflow", + "loadWorkflow": "Load Workflow", + "resetWorkflow": "Reset Workflow", + "resetWorkflowDesc": "Are you sure you want to reset this workflow?", + "resetWorkflowDesc2": "Resetting the workflow will clear all nodes, edges and workflow details.", "zoomInNodes": "Zoom In", "zoomOutNodes": "Zoom Out", "fitViewportNodes": "Fit View", diff --git a/invokeai/frontend/web/scripts/typegen.js b/invokeai/frontend/web/scripts/typegen.js index d105917e66..485cf6cbc3 100644 --- a/invokeai/frontend/web/scripts/typegen.js +++ b/invokeai/frontend/web/scripts/typegen.js @@ -27,22 +27,10 @@ async function main() { * field accepts connection input. If it does, we can make the field optional. */ - // Check if we are generating types for an invocation - const isInvocationPath = metadata.path.match( - /^#\/components\/schemas\/\w*Invocation$/ - ); - - const hasInvocationProperties = - schemaObject.properties && - ['id', 'is_intermediate', 'type'].every( - (prop) => prop in schemaObject.properties - ); - - if (isInvocationPath && hasInvocationProperties) { + if ('class' in schemaObject && schemaObject.class === 'invocation') { // We only want to make fields optional if they are required if (!Array.isArray(schemaObject?.required)) { - schemaObject.required = ['id', 'type']; - return; + schemaObject.required = []; } schemaObject.required.forEach((prop) => { @@ -61,19 +49,13 @@ async function main() { ); } }); - - schemaObject.required = [ - ...new Set(schemaObject.required.concat(['id', 'type'])), - ]; - return; } - // if ( - // 'input' in schemaObject && - // (schemaObject.input === 'any' || schemaObject.input === 'connection') - // ) { - // schemaObject.required = false; - // } + + // Check if we are generating types for an invocation output + if ('class' in schemaObject && schemaObject.class === 'output') { + // modify output types + } }, }); fs.writeFileSync(OUTPUT_FILE, types); diff --git a/invokeai/frontend/web/src/app/components/App.tsx b/invokeai/frontend/web/src/app/components/App.tsx index fa45ae93cd..c2cc4645b8 100644 --- a/invokeai/frontend/web/src/app/components/App.tsx +++ b/invokeai/frontend/web/src/app/components/App.tsx @@ -1,4 +1,4 @@ -import { Flex, Grid, Portal } from '@chakra-ui/react'; +import { Flex, Grid } from '@chakra-ui/react'; import { useLogger } from 'app/logging/useLogger'; import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/appStarted'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; @@ -6,17 +6,15 @@ import { PartialAppConfig } from 'app/types/invokeai'; import ImageUploader from 'common/components/ImageUploader'; import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal'; import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal'; -import GalleryDrawer from 'features/gallery/components/GalleryPanel'; import SiteHeader from 'features/system/components/SiteHeader'; import { configChanged } from 'features/system/store/configSlice'; import { languageSelector } from 'features/system/store/systemSelectors'; -import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton'; -import FloatingParametersPanelButtons from 'features/ui/components/FloatingParametersPanelButtons'; import InvokeTabs from 'features/ui/components/InvokeTabs'; -import ParametersDrawer from 'features/ui/components/ParametersDrawer'; import i18n from 'i18n'; import { size } from 'lodash-es'; -import { ReactNode, memo, useEffect } from 'react'; +import { ReactNode, memo, useCallback, useEffect } from 'react'; +import { ErrorBoundary } from 'react-error-boundary'; +import AppErrorBoundaryFallback from './AppErrorBoundaryFallback'; import GlobalHotkeys from './GlobalHotkeys'; import Toaster from './Toaster'; @@ -30,8 +28,13 @@ interface Props { const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => { const language = useAppSelector(languageSelector); - const logger = useLogger(); + const logger = useLogger('system'); const dispatch = useAppDispatch(); + const handleReset = useCallback(() => { + localStorage.clear(); + location.reload(); + return false; + }, []); useEffect(() => { i18n.changeLanguage(language); @@ -39,7 +42,7 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => { useEffect(() => { if (size(config)) { - logger.info({ namespace: 'App', config }, 'Received config'); + logger.info({ config }, 'Received config'); dispatch(configChanged(config)); } }, [dispatch, config, logger]); @@ -49,7 +52,10 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => { }, [dispatch]); return ( - <> + { - - - - - - - - - - + ); }; diff --git a/invokeai/frontend/web/src/app/components/AppErrorBoundaryFallback.tsx b/invokeai/frontend/web/src/app/components/AppErrorBoundaryFallback.tsx new file mode 100644 index 0000000000..76a34388eb --- /dev/null +++ b/invokeai/frontend/web/src/app/components/AppErrorBoundaryFallback.tsx @@ -0,0 +1,97 @@ +import { Flex, Heading, Link, Text, useToast } from '@chakra-ui/react'; +import IAIButton from 'common/components/IAIButton'; +import newGithubIssueUrl from 'new-github-issue-url'; +import { memo, useCallback, useMemo } from 'react'; +import { FaCopy, FaExternalLinkAlt } from 'react-icons/fa'; +import { FaArrowRotateLeft } from 'react-icons/fa6'; +import { serializeError } from 'serialize-error'; + +type Props = { + error: Error; + resetErrorBoundary: () => void; +}; + +const AppErrorBoundaryFallback = ({ error, resetErrorBoundary }: Props) => { + const toast = useToast(); + + const handleCopy = useCallback(() => { + const text = JSON.stringify(serializeError(error), null, 2); + navigator.clipboard.writeText(`\`\`\`\n${text}\n\`\`\``); + toast({ + title: 'Error Copied', + }); + }, [error, toast]); + + const url = useMemo( + () => + newGithubIssueUrl({ + user: 'invoke-ai', + repo: 'InvokeAI', + template: 'BUG_REPORT.yml', + title: `[bug]: ${error.name}: ${error.message}`, + }), + [error.message, error.name] + ); + return ( + + + Something went wrong + + + {error.name}: {error.message} + + + + } + onClick={resetErrorBoundary} + > + Reset UI + + } onClick={handleCopy}> + Copy Error + + + }>Create Issue + + + + + ); +}; + +export default memo(AppErrorBoundaryFallback); diff --git a/invokeai/frontend/web/src/app/components/GlobalHotkeys.ts b/invokeai/frontend/web/src/app/components/GlobalHotkeys.ts index bbe77dc698..ac48fcc7b1 100644 --- a/invokeai/frontend/web/src/app/components/GlobalHotkeys.ts +++ b/invokeai/frontend/web/src/app/components/GlobalHotkeys.ts @@ -1,30 +1,21 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { ctrlKeyPressed, metaKeyPressed, shiftKeyPressed, } from 'features/ui/store/hotkeysSlice'; -import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; -import { - setActiveTab, - toggleGalleryPanel, - toggleParametersPanel, - togglePinGalleryPanel, - togglePinParametersPanel, -} from 'features/ui/store/uiSlice'; +import { setActiveTab } from 'features/ui/store/uiSlice'; import { isEqual } from 'lodash-es'; import React, { memo } from 'react'; import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook'; const globalHotkeysSelector = createSelector( [stateSelector], - ({ hotkeys, ui }) => { + ({ hotkeys }) => { const { shift, ctrl, meta } = hotkeys; - const { shouldPinParametersPanel, shouldPinGallery } = ui; - return { shift, ctrl, meta, shouldPinGallery, shouldPinParametersPanel }; + return { shift, ctrl, meta }; }, { memoizeOptions: { @@ -41,9 +32,7 @@ const globalHotkeysSelector = createSelector( */ const GlobalHotkeys: React.FC = () => { const dispatch = useAppDispatch(); - const { shift, ctrl, meta, shouldPinParametersPanel, shouldPinGallery } = - useAppSelector(globalHotkeysSelector); - const activeTabName = useAppSelector(activeTabNameSelector); + const { shift, ctrl, meta } = useAppSelector(globalHotkeysSelector); useHotkeys( '*', @@ -68,34 +57,6 @@ const GlobalHotkeys: React.FC = () => { [shift, ctrl, meta] ); - useHotkeys('o', () => { - dispatch(toggleParametersPanel()); - if (activeTabName === 'unifiedCanvas' && shouldPinParametersPanel) { - dispatch(requestCanvasRescale()); - } - }); - - useHotkeys(['shift+o'], () => { - dispatch(togglePinParametersPanel()); - if (activeTabName === 'unifiedCanvas') { - dispatch(requestCanvasRescale()); - } - }); - - useHotkeys('g', () => { - dispatch(toggleGalleryPanel()); - if (activeTabName === 'unifiedCanvas' && shouldPinGallery) { - dispatch(requestCanvasRescale()); - } - }); - - useHotkeys(['shift+g'], () => { - dispatch(togglePinGalleryPanel()); - if (activeTabName === 'unifiedCanvas') { - dispatch(requestCanvasRescale()); - } - }); - useHotkeys('1', () => { dispatch(setActiveTab('txt2img')); }); @@ -112,6 +73,10 @@ const GlobalHotkeys: React.FC = () => { dispatch(setActiveTab('nodes')); }); + useHotkeys('5', () => { + dispatch(setActiveTab('modelManager')); + }); + return null; }; diff --git a/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx b/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx index 621b196ae0..9bcc7c831b 100644 --- a/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx +++ b/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx @@ -3,7 +3,7 @@ import { createLocalStorageManager, extendTheme, } from '@chakra-ui/react'; -import { ReactNode, useEffect, useMemo } from 'react'; +import { ReactNode, memo, useEffect, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { theme as invokeAITheme } from 'theme/theme'; @@ -46,4 +46,4 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) { ); } -export default ThemeLocaleProvider; +export default memo(ThemeLocaleProvider); diff --git a/invokeai/frontend/web/src/app/components/Toaster.ts b/invokeai/frontend/web/src/app/components/Toaster.ts index dff2a7c7f5..9d7149023b 100644 --- a/invokeai/frontend/web/src/app/components/Toaster.ts +++ b/invokeai/frontend/web/src/app/components/Toaster.ts @@ -3,7 +3,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { toastQueueSelector } from 'features/system/store/systemSelectors'; import { addToast, clearToastQueue } from 'features/system/store/systemSlice'; import { MakeToastArg, makeToast } from 'features/system/util/makeToast'; -import { useCallback, useEffect } from 'react'; +import { memo, useCallback, useEffect } from 'react'; /** * Logical component. Watches the toast queue and makes toasts when the queue is not empty. @@ -44,4 +44,4 @@ export const useAppToaster = () => { return toaster; }; -export default Toaster; +export default memo(Toaster); diff --git a/invokeai/frontend/web/src/app/logging/logger.ts b/invokeai/frontend/web/src/app/logging/logger.ts index 7797b8dc92..2d7b8a7744 100644 --- a/invokeai/frontend/web/src/app/logging/logger.ts +++ b/invokeai/frontend/web/src/app/logging/logger.ts @@ -9,7 +9,7 @@ export const log = Roarr.child(BASE_CONTEXT); export const $logger = atom(Roarr.child(BASE_CONTEXT)); -type LoggerNamespace = +export type LoggerNamespace = | 'images' | 'models' | 'config' diff --git a/invokeai/frontend/web/src/app/logging/useLogger.ts b/invokeai/frontend/web/src/app/logging/useLogger.ts index 6c60bd4fd0..d31bcc2660 100644 --- a/invokeai/frontend/web/src/app/logging/useLogger.ts +++ b/invokeai/frontend/web/src/app/logging/useLogger.ts @@ -1,12 +1,17 @@ -import { useStore } from '@nanostores/react'; import { createSelector } from '@reduxjs/toolkit'; import { createLogWriter } from '@roarr/browser-log-writer'; import { useAppSelector } from 'app/store/storeHooks'; import { systemSelector } from 'features/system/store/systemSelectors'; import { isEqual } from 'lodash-es'; -import { useEffect } from 'react'; +import { useEffect, useMemo } from 'react'; import { ROARR, Roarr } from 'roarr'; -import { $logger, BASE_CONTEXT, LOG_LEVEL_MAP } from './logger'; +import { + $logger, + BASE_CONTEXT, + LOG_LEVEL_MAP, + LoggerNamespace, + logger, +} from './logger'; const selector = createSelector( systemSelector, @@ -25,7 +30,7 @@ const selector = createSelector( } ); -export const useLogger = () => { +export const useLogger = (namespace: LoggerNamespace) => { const { consoleLogLevel, shouldLogToConsole } = useAppSelector(selector); // The provided Roarr browser log writer uses localStorage to config logging to console @@ -57,7 +62,7 @@ export const useLogger = () => { $logger.set(Roarr.child(newContext)); }, []); - const logger = useStore($logger); + const log = useMemo(() => logger(namespace), [namespace]); - return logger; + return log; }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts index b419e98782..770c9fc11b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts @@ -1,13 +1,17 @@ import { logger } from 'app/logging/logger'; import { resetCanvas } from 'features/canvas/store/canvasSlice'; -import { controlNetReset } from 'features/controlNet/store/controlNetSlice'; +import { + controlNetImageChanged, + controlNetProcessedImageChanged, +} from 'features/controlNet/store/controlNetSlice'; import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions'; import { isModalOpenChanged } from 'features/deleteImageModal/store/slice'; import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors'; import { imageSelected } from 'features/gallery/store/gallerySlice'; -import { nodeEditorReset } from 'features/nodes/store/nodesSlice'; +import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice'; +import { isInvocationNode } from 'features/nodes/types/types'; import { clearInitialImage } from 'features/parameters/store/generationSlice'; -import { clamp } from 'lodash-es'; +import { clamp, forEach } from 'lodash-es'; import { api } from 'services/api'; import { imagesApi } from 'services/api/endpoints/images'; import { imagesAdapter } from 'services/api/util'; @@ -73,22 +77,61 @@ export const addRequestedSingleImageDeletionListener = () => { } // We need to reset the features where the image is in use - none of these work if their image(s) don't exist - if (imageUsage.isCanvasImage) { dispatch(resetCanvas()); } - if (imageUsage.isControlNetImage) { - dispatch(controlNetReset()); - } + imageDTOs.forEach((imageDTO) => { + // reset init image if we deleted it + if ( + getState().generation.initialImage?.imageName === imageDTO.image_name + ) { + dispatch(clearInitialImage()); + } - if (imageUsage.isInitialImage) { - dispatch(clearInitialImage()); - } + // reset controlNets that use the deleted images + forEach(getState().controlNet.controlNets, (controlNet) => { + if ( + controlNet.controlImage === imageDTO.image_name || + controlNet.processedControlImage === imageDTO.image_name + ) { + dispatch( + controlNetImageChanged({ + controlNetId: controlNet.controlNetId, + controlImage: null, + }) + ); + dispatch( + controlNetProcessedImageChanged({ + controlNetId: controlNet.controlNetId, + processedControlImage: null, + }) + ); + } + }); - if (imageUsage.isNodesImage) { - dispatch(nodeEditorReset()); - } + // reset nodes that use the deleted images + getState().nodes.nodes.forEach((node) => { + if (!isInvocationNode(node)) { + return; + } + + forEach(node.data.inputs, (input) => { + if ( + input.type === 'ImageField' && + input.value?.image_name === imageDTO.image_name + ) { + dispatch( + fieldImageValueChanged({ + nodeId: node.data.id, + fieldName: input.name, + value: undefined, + }) + ); + } + }); + }); + }); // Delete from server const { requestId } = dispatch( @@ -154,17 +197,58 @@ export const addRequestedMultipleImageDeletionListener = () => { dispatch(resetCanvas()); } - if (imagesUsage.some((i) => i.isControlNetImage)) { - dispatch(controlNetReset()); - } + imageDTOs.forEach((imageDTO) => { + // reset init image if we deleted it + if ( + getState().generation.initialImage?.imageName === + imageDTO.image_name + ) { + dispatch(clearInitialImage()); + } - if (imagesUsage.some((i) => i.isInitialImage)) { - dispatch(clearInitialImage()); - } + // reset controlNets that use the deleted images + forEach(getState().controlNet.controlNets, (controlNet) => { + if ( + controlNet.controlImage === imageDTO.image_name || + controlNet.processedControlImage === imageDTO.image_name + ) { + dispatch( + controlNetImageChanged({ + controlNetId: controlNet.controlNetId, + controlImage: null, + }) + ); + dispatch( + controlNetProcessedImageChanged({ + controlNetId: controlNet.controlNetId, + processedControlImage: null, + }) + ); + } + }); - if (imagesUsage.some((i) => i.isNodesImage)) { - dispatch(nodeEditorReset()); - } + // reset nodes that use the deleted images + getState().nodes.nodes.forEach((node) => { + if (!isInvocationNode(node)) { + return; + } + + forEach(node.data.inputs, (input) => { + if ( + input.type === 'ImageField' && + input.value?.image_name === imageDTO.image_name + ) { + dispatch( + fieldImageValueChanged({ + nodeId: node.data.id, + fieldName: input.name, + value: undefined, + }) + ); + } + }); + }); + }); } catch { // no-op } diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts index 32a6cce203..739bbd7110 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts @@ -5,6 +5,7 @@ import { modelsApi } from 'services/api/endpoints/models'; import { receivedOpenAPISchema } from 'services/api/thunks/schema'; import { appSocketConnected, socketConnected } from 'services/events/actions'; import { startAppListening } from '../..'; +import { size } from 'lodash-es'; export const addSocketConnectedEventListener = () => { startAppListening({ @@ -18,7 +19,7 @@ export const addSocketConnectedEventListener = () => { const { disabledTabs } = config; - if (!nodes.schema && !disabledTabs.includes('nodes')) { + if (!size(nodes.nodeTemplates) && !disabledTabs.includes('nodes')) { dispatch(receivedOpenAPISchema()); } diff --git a/invokeai/frontend/web/src/common/components/IAIButton.tsx b/invokeai/frontend/web/src/common/components/IAIButton.tsx index d1e77537cc..4058296aaf 100644 --- a/invokeai/frontend/web/src/common/components/IAIButton.tsx +++ b/invokeai/frontend/web/src/common/components/IAIButton.tsx @@ -8,8 +8,8 @@ import { import { memo, ReactNode } from 'react'; export interface IAIButtonProps extends ButtonProps { - tooltip?: string; - tooltipProps?: Omit; + tooltip?: TooltipProps['label']; + tooltipProps?: Omit; isChecked?: boolean; children: ReactNode; } diff --git a/invokeai/frontend/web/src/common/components/IAICollapse.tsx b/invokeai/frontend/web/src/common/components/IAICollapse.tsx index 0ce767ed9d..a5e08e6ddc 100644 --- a/invokeai/frontend/web/src/common/components/IAICollapse.tsx +++ b/invokeai/frontend/web/src/common/components/IAICollapse.tsx @@ -34,14 +34,10 @@ const IAICollapse = (props: IAIToggleCollapseProps) => { gap: 2, borderTopRadius: 'base', borderBottomRadius: isOpen ? 0 : 'base', - bg: isOpen - ? mode('base.200', 'base.750')(colorMode) - : mode('base.150', 'base.800')(colorMode), + bg: mode('base.250', 'base.750')(colorMode), color: mode('base.900', 'base.100')(colorMode), _hover: { - bg: isOpen - ? mode('base.250', 'base.700')(colorMode) - : mode('base.200', 'base.750')(colorMode), + bg: mode('base.300', 'base.700')(colorMode), }, fontSize: 'sm', fontWeight: 600, @@ -90,9 +86,10 @@ const IAICollapse = (props: IAIToggleCollapseProps) => { { const [isHovered, setIsHovered] = useState(false); const handleMouseOver = useCallback( (e: MouseEvent) => { - if (onMouseOver) onMouseOver(e); + if (onMouseOver) { + onMouseOver(e); + } setIsHovered(true); }, [onMouseOver] ); const handleMouseOut = useCallback( (e: MouseEvent) => { - if (onMouseOut) onMouseOut(e); + if (onMouseOut) { + onMouseOut(e); + } setIsHovered(false); }, [onMouseOut] @@ -122,7 +126,7 @@ const IAIDndImage = (props: IAIDndImageProps) => { ? {} : { cursor: 'pointer', - bg: mode('base.200', 'base.800')(colorMode), + bg: mode('base.200', 'base.700')(colorMode), _hover: { bg: mode('base.300', 'base.650')(colorMode), color: mode('base.500', 'base.300')(colorMode), diff --git a/invokeai/frontend/web/src/common/components/IAIErrorLoadingImageFallback.tsx b/invokeai/frontend/web/src/common/components/IAIErrorLoadingImageFallback.tsx index 2136acc3c3..0a5d4fb12f 100644 --- a/invokeai/frontend/web/src/common/components/IAIErrorLoadingImageFallback.tsx +++ b/invokeai/frontend/web/src/common/components/IAIErrorLoadingImageFallback.tsx @@ -1,4 +1,5 @@ import { Box, Flex, Icon } from '@chakra-ui/react'; +import { memo } from 'react'; import { FaExclamation } from 'react-icons/fa'; const IAIErrorLoadingImageFallback = () => { @@ -39,4 +40,4 @@ const IAIErrorLoadingImageFallback = () => { ); }; -export default IAIErrorLoadingImageFallback; +export default memo(IAIErrorLoadingImageFallback); diff --git a/invokeai/frontend/web/src/common/components/IAIFillSkeleton.tsx b/invokeai/frontend/web/src/common/components/IAIFillSkeleton.tsx index a3c83cb734..8081714432 100644 --- a/invokeai/frontend/web/src/common/components/IAIFillSkeleton.tsx +++ b/invokeai/frontend/web/src/common/components/IAIFillSkeleton.tsx @@ -1,4 +1,5 @@ import { Box, Skeleton } from '@chakra-ui/react'; +import { memo } from 'react'; const IAIFillSkeleton = () => { return ( @@ -27,4 +28,4 @@ const IAIFillSkeleton = () => { ); }; -export default IAIFillSkeleton; +export default memo(IAIFillSkeleton); diff --git a/invokeai/frontend/web/src/common/components/IAIIconButton.tsx b/invokeai/frontend/web/src/common/components/IAIIconButton.tsx index ed1514055e..0a42430689 100644 --- a/invokeai/frontend/web/src/common/components/IAIIconButton.tsx +++ b/invokeai/frontend/web/src/common/components/IAIIconButton.tsx @@ -9,8 +9,8 @@ import { memo } from 'react'; export type IAIIconButtonProps = IconButtonProps & { role?: string; - tooltip?: string; - tooltipProps?: Omit; + tooltip?: TooltipProps['label']; + tooltipProps?: Omit; isChecked?: boolean; }; diff --git a/invokeai/frontend/web/src/common/components/ImageMetadataOverlay.tsx b/invokeai/frontend/web/src/common/components/ImageMetadataOverlay.tsx index 3ef7d8f83e..765dd3c000 100644 --- a/invokeai/frontend/web/src/common/components/ImageMetadataOverlay.tsx +++ b/invokeai/frontend/web/src/common/components/ImageMetadataOverlay.tsx @@ -1,4 +1,5 @@ import { Badge, Flex } from '@chakra-ui/react'; +import { memo } from 'react'; import { ImageDTO } from 'services/api/types'; type ImageMetadataOverlayProps = { @@ -26,4 +27,4 @@ const ImageMetadataOverlay = ({ imageDTO }: ImageMetadataOverlayProps) => { ); }; -export default ImageMetadataOverlay; +export default memo(ImageMetadataOverlay); diff --git a/invokeai/frontend/web/src/common/components/ImageUploadOverlay.tsx b/invokeai/frontend/web/src/common/components/ImageUploadOverlay.tsx index b2d5ddb2da..5c91a7ceda 100644 --- a/invokeai/frontend/web/src/common/components/ImageUploadOverlay.tsx +++ b/invokeai/frontend/web/src/common/components/ImageUploadOverlay.tsx @@ -1,4 +1,5 @@ import { Box, Flex, Heading } from '@chakra-ui/react'; +import { memo } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; type ImageUploadOverlayProps = { @@ -87,4 +88,4 @@ const ImageUploadOverlay = (props: ImageUploadOverlayProps) => { ); }; -export default ImageUploadOverlay; +export default memo(ImageUploadOverlay); diff --git a/invokeai/frontend/web/src/common/components/ImageUploader.tsx b/invokeai/frontend/web/src/common/components/ImageUploader.tsx index c990a9a24e..5f95056a51 100644 --- a/invokeai/frontend/web/src/common/components/ImageUploader.tsx +++ b/invokeai/frontend/web/src/common/components/ImageUploader.tsx @@ -150,7 +150,9 @@ const ImageUploader = (props: ImageUploaderProps) => { {...getRootProps({ style: {} })} onKeyDown={(e: KeyboardEvent) => { // Bail out if user hits spacebar - do not open the uploader - if (e.key === ' ') return; + if (e.key === ' ') { + return; + } }} > diff --git a/invokeai/frontend/web/src/common/components/SelectImagePlaceholder.tsx b/invokeai/frontend/web/src/common/components/SelectImagePlaceholder.tsx index a19d447755..2db202ddc0 100644 --- a/invokeai/frontend/web/src/common/components/SelectImagePlaceholder.tsx +++ b/invokeai/frontend/web/src/common/components/SelectImagePlaceholder.tsx @@ -1,4 +1,5 @@ import { Flex, Icon } from '@chakra-ui/react'; +import { memo } from 'react'; import { FaImage } from 'react-icons/fa'; const SelectImagePlaceholder = () => { @@ -19,4 +20,4 @@ const SelectImagePlaceholder = () => { ); }; -export default SelectImagePlaceholder; +export default memo(SelectImagePlaceholder); diff --git a/invokeai/frontend/web/src/common/components/SelectionOverlay.tsx b/invokeai/frontend/web/src/common/components/SelectionOverlay.tsx index 9ff6cd341b..aed5e1f083 100644 --- a/invokeai/frontend/web/src/common/components/SelectionOverlay.tsx +++ b/invokeai/frontend/web/src/common/components/SelectionOverlay.tsx @@ -1,4 +1,5 @@ import { Box } from '@chakra-ui/react'; +import { memo } from 'react'; type Props = { isSelected: boolean; @@ -18,6 +19,7 @@ const SelectionOverlay = ({ isSelected, isHovered }: Props) => { opacity: isSelected ? 1 : 0.7, transitionProperty: 'common', transitionDuration: '0.1s', + pointerEvents: 'none', shadow: isSelected ? isHovered ? 'hoverSelected.light' @@ -39,4 +41,4 @@ const SelectionOverlay = ({ isSelected, isHovered }: Props) => { ); }; -export default SelectionOverlay; +export default memo(SelectionOverlay); diff --git a/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts b/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts index f43ec1851f..e06a1106c1 100644 --- a/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts +++ b/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts @@ -2,71 +2,108 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -// import { validateSeedWeights } from 'common/util/seedWeightPairs'; +import { isInvocationNode } from 'features/nodes/types/types'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; -import { forEach } from 'lodash-es'; -import { NON_REFINER_BASE_MODELS } from 'services/api/constants'; -import { modelsApi } from '../../services/api/endpoints/models'; +import { forEach, map } from 'lodash-es'; +import { getConnectedEdges } from 'reactflow'; -const readinessSelector = createSelector( +const selector = createSelector( [stateSelector, activeTabNameSelector], (state, activeTabName) => { - const { generation, system } = state; - const { initialImage } = generation; + const { generation, system, nodes } = state; + const { initialImage, model } = generation; const { isProcessing, isConnected } = system; - let isReady = true; - const reasonsWhyNotReady: string[] = []; + const reasons: string[] = []; - if (activeTabName === 'img2img' && !initialImage) { - isReady = false; - reasonsWhyNotReady.push('No initial image selected'); - } - - const { isSuccess: mainModelsSuccessfullyLoaded } = - modelsApi.endpoints.getMainModels.select(NON_REFINER_BASE_MODELS)(state); - if (!mainModelsSuccessfullyLoaded) { - isReady = false; - reasonsWhyNotReady.push('Models are not loaded'); - } - - // TODO: job queue // Cannot generate if already processing an image if (isProcessing) { - isReady = false; - reasonsWhyNotReady.push('System Busy'); + reasons.push('System busy'); } // Cannot generate if not connected if (!isConnected) { - isReady = false; - reasonsWhyNotReady.push('System Disconnected'); + reasons.push('System disconnected'); } - // // Cannot generate variations without valid seed weights - // if ( - // shouldGenerateVariations && - // (!(validateSeedWeights(seedWeights) || seedWeights === '') || seed === -1) - // ) { - // isReady = false; - // reasonsWhyNotReady.push('Seed-Weights badly formatted.'); - // } + if (activeTabName === 'img2img' && !initialImage) { + reasons.push('No initial image selected'); + } - forEach(state.controlNet.controlNets, (controlNet, id) => { - if (!controlNet.model) { - isReady = false; - reasonsWhyNotReady.push(`ControlNet ${id} has no model selected.`); + if (activeTabName === 'nodes' && nodes.shouldValidateGraph) { + if (!nodes.nodes.length) { + reasons.push('No nodes in graph'); } - }); - // All good - return { isReady, reasonsWhyNotReady }; + nodes.nodes.forEach((node) => { + if (!isInvocationNode(node)) { + return; + } + + const nodeTemplate = nodes.nodeTemplates[node.data.type]; + + if (!nodeTemplate) { + // Node type not found + reasons.push('Missing node template'); + return; + } + + const connectedEdges = getConnectedEdges([node], nodes.edges); + + forEach(node.data.inputs, (field) => { + const fieldTemplate = nodeTemplate.inputs[field.name]; + const hasConnection = connectedEdges.some( + (edge) => + edge.target === node.id && edge.targetHandle === field.name + ); + + if (!fieldTemplate) { + reasons.push('Missing field template'); + return; + } + + if (fieldTemplate.required && !field.value && !hasConnection) { + reasons.push( + `${node.data.label || nodeTemplate.title} -> ${ + field.label || fieldTemplate.title + } missing input` + ); + return; + } + }); + }); + } else { + if (!model) { + reasons.push('No model selected'); + } + + if (state.controlNet.isEnabled) { + map(state.controlNet.controlNets).forEach((controlNet, i) => { + if (!controlNet.isEnabled) { + return; + } + if (!controlNet.model) { + reasons.push(`ControlNet ${i + 1} has no model selected.`); + } + + if ( + !controlNet.controlImage || + (!controlNet.processedControlImage && + controlNet.processorType !== 'none') + ) { + reasons.push(`ControlNet ${i + 1} has no control image`); + } + }); + } + } + + return { isReady: !reasons.length, isProcessing, reasons }; }, defaultSelectorOptions ); export const useIsReadyToInvoke = () => { - const { isReady } = useAppSelector(readinessSelector); - return isReady; + const { isReady, isProcessing, reasons } = useAppSelector(selector); + return { isReady, isProcessing, reasons }; }; diff --git a/invokeai/frontend/web/src/common/hooks/useResolution.ts b/invokeai/frontend/web/src/common/hooks/useResolution.ts index 96b95ee074..fb52555be8 100644 --- a/invokeai/frontend/web/src/common/hooks/useResolution.ts +++ b/invokeai/frontend/web/src/common/hooks/useResolution.ts @@ -11,8 +11,14 @@ export default function useResolution(): const tabletResolutions = ['md', 'lg']; const desktopResolutions = ['xl', '2xl']; - if (mobileResolutions.includes(breakpointValue)) return 'mobile'; - if (tabletResolutions.includes(breakpointValue)) return 'tablet'; - if (desktopResolutions.includes(breakpointValue)) return 'desktop'; + if (mobileResolutions.includes(breakpointValue)) { + return 'mobile'; + } + if (tabletResolutions.includes(breakpointValue)) { + return 'tablet'; + } + if (desktopResolutions.includes(breakpointValue)) { + return 'desktop'; + } return 'unknown'; } diff --git a/invokeai/frontend/web/src/common/util/colorTokenToCssVar.ts b/invokeai/frontend/web/src/common/util/colorTokenToCssVar.ts new file mode 100644 index 0000000000..e29005186f --- /dev/null +++ b/invokeai/frontend/web/src/common/util/colorTokenToCssVar.ts @@ -0,0 +1,2 @@ +export const colorTokenToCssVar = (colorToken: string) => + `var(--invokeai-colors-${colorToken.split('.').join('-')}`; diff --git a/invokeai/frontend/web/src/common/util/dateComparator.ts b/invokeai/frontend/web/src/common/util/dateComparator.ts index ea0dc28b6d..27af542261 100644 --- a/invokeai/frontend/web/src/common/util/dateComparator.ts +++ b/invokeai/frontend/web/src/common/util/dateComparator.ts @@ -6,7 +6,11 @@ export const dateComparator = (a: string, b: string) => { const dateB = new Date(b); // sort in ascending order - if (dateA > dateB) return 1; - if (dateA < dateB) return -1; + if (dateA > dateB) { + return 1; + } + if (dateA < dateB) { + return -1; + } return 0; }; diff --git a/invokeai/frontend/web/src/common/util/openBase64ImageInTab.ts b/invokeai/frontend/web/src/common/util/openBase64ImageInTab.ts index 0e18ccb45f..71d3bcd661 100644 --- a/invokeai/frontend/web/src/common/util/openBase64ImageInTab.ts +++ b/invokeai/frontend/web/src/common/util/openBase64ImageInTab.ts @@ -5,7 +5,9 @@ type Base64AndCaption = { const openBase64ImageInTab = (images: Base64AndCaption[]) => { const w = window.open(''); - if (!w) return; + if (!w) { + return; + } images.forEach((i) => { const image = new Image(); diff --git a/invokeai/frontend/web/src/features/canvas/components/ClearCanvasHistoryButtonModal.tsx b/invokeai/frontend/web/src/features/canvas/components/ClearCanvasHistoryButtonModal.tsx index 49a13c401c..a86497aade 100644 --- a/invokeai/frontend/web/src/features/canvas/components/ClearCanvasHistoryButtonModal.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/ClearCanvasHistoryButtonModal.tsx @@ -5,6 +5,7 @@ import { clearCanvasHistory } from 'features/canvas/store/canvasSlice'; import { useTranslation } from 'react-i18next'; import { FaTrash } from 'react-icons/fa'; import { isStagingSelector } from '../store/canvasSelectors'; +import { memo } from 'react'; const ClearCanvasHistoryButtonModal = () => { const isStaging = useAppSelector(isStagingSelector); @@ -28,4 +29,4 @@ const ClearCanvasHistoryButtonModal = () => { ); }; -export default ClearCanvasHistoryButtonModal; +export default memo(ClearCanvasHistoryButtonModal); diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvas.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvas.tsx index 7a82e64270..4f9e47282d 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvas.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvas.tsx @@ -1,6 +1,6 @@ import { Box, chakra, Flex } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; -import { useAppSelector } from 'app/store/storeHooks'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { canvasSelector, @@ -9,7 +9,7 @@ import { import Konva from 'konva'; import { KonvaEventObject } from 'konva/lib/Node'; import { Vector2d } from 'konva/lib/types'; -import { useCallback, useRef } from 'react'; +import { memo, useCallback, useEffect, useRef } from 'react'; import { Layer, Stage } from 'react-konva'; import useCanvasDragMove from '../hooks/useCanvasDragMove'; import useCanvasHotkeys from '../hooks/useCanvasHotkeys'; @@ -18,6 +18,7 @@ import useCanvasMouseMove from '../hooks/useCanvasMouseMove'; import useCanvasMouseOut from '../hooks/useCanvasMouseOut'; import useCanvasMouseUp from '../hooks/useCanvasMouseUp'; import useCanvasWheel from '../hooks/useCanvasZoom'; +import { canvasResized } from '../store/canvasSlice'; import { setCanvasBaseLayer, setCanvasStage, @@ -106,7 +107,8 @@ const IAICanvas = () => { shouldAntialias, } = useAppSelector(selector); useCanvasHotkeys(); - + const dispatch = useAppDispatch(); + const containerRef = useRef(null); const stageRef = useRef(null); const canvasBaseLayerRef = useRef(null); @@ -137,8 +139,30 @@ const IAICanvas = () => { const { handleDragStart, handleDragMove, handleDragEnd } = useCanvasDragMove(); + useEffect(() => { + if (!containerRef.current) { + return; + } + const resizeObserver = new ResizeObserver((entries) => { + for (const entry of entries) { + if (entry.contentBoxSize) { + const { width, height } = entry.contentRect; + dispatch(canvasResized({ width, height })); + } + } + }); + + resizeObserver.observe(containerRef.current); + + return () => { + resizeObserver.disconnect(); + }; + }, [dispatch]); + return ( { borderRadius: 'base', }} > - + { /> - - + + ); }; -export default IAICanvas; +export default memo(IAICanvas); diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasBoundingBoxOverlay.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasBoundingBoxOverlay.tsx index e90d2c4d25..22a8848cad 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasBoundingBoxOverlay.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasBoundingBoxOverlay.tsx @@ -4,6 +4,7 @@ import { isEqual } from 'lodash-es'; import { Group, Rect } from 'react-konva'; import { canvasSelector } from '../store/canvasSelectors'; +import { memo } from 'react'; const selector = createSelector( canvasSelector, @@ -67,4 +68,4 @@ const IAICanvasBoundingBoxOverlay = () => { ); }; -export default IAICanvasBoundingBoxOverlay; +export default memo(IAICanvasBoundingBoxOverlay); diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasGrid.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasGrid.tsx index 1b97acba71..50a68357fb 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasGrid.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasGrid.tsx @@ -6,7 +6,7 @@ import { useAppSelector } from 'app/store/storeHooks'; import { canvasSelector } from 'features/canvas/store/canvasSelectors'; import { isEqual, range } from 'lodash-es'; -import { ReactNode, useCallback, useLayoutEffect, useState } from 'react'; +import { ReactNode, memo, useCallback, useLayoutEffect, useState } from 'react'; import { Group, Line as KonvaLine } from 'react-konva'; const selector = createSelector( @@ -117,4 +117,4 @@ const IAICanvasGrid = () => { return {gridLines}; }; -export default IAICanvasGrid; +export default memo(IAICanvasGrid); diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasImage.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasImage.tsx index eb41857e46..9f8829c280 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasImage.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasImage.tsx @@ -4,6 +4,7 @@ import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import useImage from 'use-image'; import { CanvasImage } from '../store/canvasTypes'; import { $authToken } from 'services/api/client'; +import { memo } from 'react'; type IAICanvasImageProps = { canvasImage: CanvasImage; @@ -25,4 +26,4 @@ const IAICanvasImage = (props: IAICanvasImageProps) => { return ; }; -export default IAICanvasImage; +export default memo(IAICanvasImage); diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasIntermediateImage.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasIntermediateImage.tsx index ea5e9a6486..b636ef9528 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasIntermediateImage.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasIntermediateImage.tsx @@ -4,7 +4,7 @@ import { systemSelector } from 'features/system/store/systemSelectors'; import { ImageConfig } from 'konva/lib/shapes/Image'; import { isEqual } from 'lodash-es'; -import { useEffect, useState } from 'react'; +import { memo, useEffect, useState } from 'react'; import { Image as KonvaImage } from 'react-konva'; import { canvasSelector } from '../store/canvasSelectors'; @@ -66,4 +66,4 @@ const IAICanvasIntermediateImage = (props: Props) => { ) : null; }; -export default IAICanvasIntermediateImage; +export default memo(IAICanvasIntermediateImage); diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasMaskCompositer.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasMaskCompositer.tsx index e374d2aa7b..e65f51cade 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasMaskCompositer.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasMaskCompositer.tsx @@ -7,7 +7,7 @@ import { Rect } from 'react-konva'; import { rgbaColorToString } from 'features/canvas/util/colorToString'; import Konva from 'konva'; import { isNumber } from 'lodash-es'; -import { useCallback, useEffect, useRef, useState } from 'react'; +import { memo, useCallback, useEffect, useRef, useState } from 'react'; export const canvasMaskCompositerSelector = createSelector( canvasSelector, @@ -125,7 +125,9 @@ const IAICanvasMaskCompositer = (props: IAICanvasMaskCompositerProps) => { }, [offset]); useEffect(() => { - if (fillPatternImage) return; + if (fillPatternImage) { + return; + } const image = new Image(); image.onload = () => { @@ -135,7 +137,9 @@ const IAICanvasMaskCompositer = (props: IAICanvasMaskCompositerProps) => { }, [fillPatternImage, maskColorString]); useEffect(() => { - if (!fillPatternImage) return; + if (!fillPatternImage) { + return; + } fillPatternImage.src = getColoredSVG(maskColorString); }, [fillPatternImage, maskColorString]); @@ -151,8 +155,9 @@ const IAICanvasMaskCompositer = (props: IAICanvasMaskCompositerProps) => { !isNumber(stageScale) || !isNumber(stageDimensions.width) || !isNumber(stageDimensions.height) - ) + ) { return null; + } return ( { ); }; -export default IAICanvasMaskCompositer; +export default memo(IAICanvasMaskCompositer); diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasMaskLines.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasMaskLines.tsx index a553653901..ca91e11350 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasMaskLines.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasMaskLines.tsx @@ -6,6 +6,7 @@ import { isEqual } from 'lodash-es'; import { Group, Line } from 'react-konva'; import { isCanvasMaskLine } from '../store/canvasTypes'; +import { memo } from 'react'; export const canvasLinesSelector = createSelector( [canvasSelector], @@ -52,4 +53,4 @@ const IAICanvasLines = (props: InpaintingCanvasLinesProps) => { ); }; -export default IAICanvasLines; +export default memo(IAICanvasLines); diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasObjectRenderer.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasObjectRenderer.tsx index ec1e87cca7..c56dba2b8c 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasObjectRenderer.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasObjectRenderer.tsx @@ -12,6 +12,7 @@ import { isCanvasFillRect, } from '../store/canvasTypes'; import IAICanvasImage from './IAICanvasImage'; +import { memo } from 'react'; const selector = createSelector( [canvasSelector], @@ -33,7 +34,9 @@ const selector = createSelector( const IAICanvasObjectRenderer = () => { const { objects } = useAppSelector(selector); - if (!objects) return null; + if (!objects) { + return null; + } return ( @@ -101,4 +104,4 @@ const IAICanvasObjectRenderer = () => { ); }; -export default IAICanvasObjectRenderer; +export default memo(IAICanvasObjectRenderer); diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasResizer.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasResizer.tsx deleted file mode 100644 index d16a5dab87..0000000000 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasResizer.tsx +++ /dev/null @@ -1,89 +0,0 @@ -import { Flex, Spinner } from '@chakra-ui/react'; -import { createSelector } from '@reduxjs/toolkit'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { - canvasSelector, - initialCanvasImageSelector, -} from 'features/canvas/store/canvasSelectors'; -import { - resizeAndScaleCanvas, - resizeCanvas, - setCanvasContainerDimensions, - setDoesCanvasNeedScaling, -} from 'features/canvas/store/canvasSlice'; -import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; -import { useLayoutEffect, useRef } from 'react'; - -const canvasResizerSelector = createSelector( - canvasSelector, - initialCanvasImageSelector, - activeTabNameSelector, - (canvas, initialCanvasImage, activeTabName) => { - const { doesCanvasNeedScaling, isCanvasInitialized } = canvas; - return { - doesCanvasNeedScaling, - activeTabName, - initialCanvasImage, - isCanvasInitialized, - }; - } -); - -const IAICanvasResizer = () => { - const dispatch = useAppDispatch(); - const { - doesCanvasNeedScaling, - activeTabName, - initialCanvasImage, - isCanvasInitialized, - } = useAppSelector(canvasResizerSelector); - - const ref = useRef(null); - - useLayoutEffect(() => { - window.setTimeout(() => { - if (!ref.current) return; - - const { clientWidth, clientHeight } = ref.current; - - dispatch( - setCanvasContainerDimensions({ - width: clientWidth, - height: clientHeight, - }) - ); - - if (!isCanvasInitialized) { - dispatch(resizeAndScaleCanvas()); - } else { - dispatch(resizeCanvas()); - } - - dispatch(setDoesCanvasNeedScaling(false)); - }, 0); - }, [ - dispatch, - initialCanvasImage, - doesCanvasNeedScaling, - activeTabName, - isCanvasInitialized, - ]); - - return ( - - - - ); -}; - -export default IAICanvasResizer; diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingArea.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingArea.tsx index 5355e28762..fa73f020da 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingArea.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingArea.tsx @@ -6,6 +6,7 @@ import { isEqual } from 'lodash-es'; import { Group, Rect } from 'react-konva'; import IAICanvasImage from './IAICanvasImage'; +import { memo } from 'react'; const selector = createSelector( [canvasSelector], @@ -88,4 +89,4 @@ const IAICanvasStagingArea = (props: Props) => { ); }; -export default IAICanvasStagingArea; +export default memo(IAICanvasStagingArea); diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingAreaToolbar.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingAreaToolbar.tsx index 1929bff8f9..cc15141d38 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingAreaToolbar.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingAreaToolbar.tsx @@ -13,7 +13,7 @@ import { } from 'features/canvas/store/canvasSlice'; import { isEqual } from 'lodash-es'; -import { useCallback } from 'react'; +import { memo, useCallback } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; import { useTranslation } from 'react-i18next'; import { @@ -129,7 +129,9 @@ const IAICanvasStagingAreaToolbar = () => { currentStagingAreaImage?.imageName ?? skipToken ); - if (!currentStagingAreaImage) return null; + if (!currentStagingAreaImage) { + return null; + } return ( { w="100%" align="center" justify="center" - filter="drop-shadow(0 0.5rem 1rem rgba(0,0,0))" onMouseOver={handleMouseOver} onMouseOut={handleMouseOut} > - + { ); }; -export default IAICanvasStagingAreaToolbar; +export default memo(IAICanvasStagingAreaToolbar); diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasStatusText.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasStatusText.tsx index 8c1dfbb86f..7aa9cad003 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasStatusText.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasStatusText.tsx @@ -7,6 +7,7 @@ import { isEqual } from 'lodash-es'; import { useTranslation } from 'react-i18next'; import roundToHundreth from '../util/roundToHundreth'; import IAICanvasStatusTextCursorPos from './IAICanvasStatusText/IAICanvasStatusTextCursorPos'; +import { memo } from 'react'; const warningColor = 'var(--invokeai-colors-warning-500)'; @@ -162,4 +163,4 @@ const IAICanvasStatusText = () => { ); }; -export default IAICanvasStatusText; +export default memo(IAICanvasStatusText); diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolPreview.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolPreview.tsx index 8ad58e020c..7529ec42a0 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolPreview.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolPreview.tsx @@ -10,6 +10,7 @@ import { COLOR_PICKER_SIZE, COLOR_PICKER_STROKE_RADIUS, } from '../util/constants'; +import { memo } from 'react'; const canvasBrushPreviewSelector = createSelector( canvasSelector, @@ -134,7 +135,9 @@ const IAICanvasToolPreview = (props: GroupConfig) => { clip, } = useAppSelector(canvasBrushPreviewSelector); - if (!shouldDrawBrushPreview) return null; + if (!shouldDrawBrushPreview) { + return null; + } return ( @@ -206,4 +209,4 @@ const IAICanvasToolPreview = (props: GroupConfig) => { ); }; -export default IAICanvasToolPreview; +export default memo(IAICanvasToolPreview); diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasBoundingBox.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasBoundingBox.tsx index 41c281d259..0f94b1c57a 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasBoundingBox.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasBoundingBox.tsx @@ -19,7 +19,7 @@ import { KonvaEventObject } from 'konva/lib/Node'; import { Vector2d } from 'konva/lib/types'; import { isEqual } from 'lodash-es'; -import { useCallback, useEffect, useRef, useState } from 'react'; +import { memo, useCallback, useEffect, useRef, useState } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; import { Group, Rect, Transformer } from 'react-konva'; @@ -85,7 +85,9 @@ const IAICanvasBoundingBox = (props: IAICanvasBoundingBoxPreviewProps) => { useState(false); useEffect(() => { - if (!transformerRef.current || !shapeRef.current) return; + if (!transformerRef.current || !shapeRef.current) { + return; + } transformerRef.current.nodes([shapeRef.current]); transformerRef.current.getLayer()?.batchDraw(); }, []); @@ -133,7 +135,9 @@ const IAICanvasBoundingBox = (props: IAICanvasBoundingBoxPreviewProps) => { * not its width and height. We need to un-scale the width and height before * setting the values. */ - if (!shapeRef.current) return; + if (!shapeRef.current) { + return; + } const rect = shapeRef.current; @@ -313,4 +317,4 @@ const IAICanvasBoundingBox = (props: IAICanvasBoundingBoxPreviewProps) => { ); }; -export default IAICanvasBoundingBox; +export default memo(IAICanvasBoundingBox); diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasMaskOptions.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasMaskOptions.tsx index 25ef295631..76211a2e95 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasMaskOptions.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasMaskOptions.tsx @@ -20,6 +20,7 @@ import { } from 'features/canvas/store/canvasSlice'; import { rgbaColorToString } from 'features/canvas/util/colorToString'; import { isEqual } from 'lodash-es'; +import { memo } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; import { useTranslation } from 'react-i18next'; @@ -150,4 +151,4 @@ const IAICanvasMaskOptions = () => { ); }; -export default IAICanvasMaskOptions; +export default memo(IAICanvasMaskOptions); diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasSettingsButtonPopover.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasSettingsButtonPopover.tsx index ae03df8409..aae2da5632 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasSettingsButtonPopover.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasSettingsButtonPopover.tsx @@ -18,7 +18,7 @@ import { } from 'features/canvas/store/canvasSlice'; import { isEqual } from 'lodash-es'; -import { ChangeEvent } from 'react'; +import { ChangeEvent, memo } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; import { useTranslation } from 'react-i18next'; import { FaWrench } from 'react-icons/fa'; @@ -163,4 +163,4 @@ const IAICanvasSettingsButtonPopover = () => { ); }; -export default IAICanvasSettingsButtonPopover; +export default memo(IAICanvasSettingsButtonPopover); diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasToolChooserOptions.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasToolChooserOptions.tsx index 158e2954af..a3e8f6af8b 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasToolChooserOptions.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasToolChooserOptions.tsx @@ -18,6 +18,7 @@ import { } from 'features/canvas/store/canvasSlice'; import { systemSelector } from 'features/system/store/systemSelectors'; import { clamp, isEqual } from 'lodash-es'; +import { memo } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; import { useTranslation } from 'react-i18next'; @@ -252,4 +253,4 @@ const IAICanvasToolChooserOptions = () => { ); }; -export default IAICanvasToolChooserOptions; +export default memo(IAICanvasToolChooserOptions); diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasToolbar.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasToolbar.tsx index 26ccfe31b6..49ce63d25f 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasToolbar.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasToolbar.tsx @@ -18,7 +18,6 @@ import { import { resetCanvas, resetCanvasView, - resizeAndScaleCanvas, setIsMaskEnabled, setLayer, setTool, @@ -48,6 +47,7 @@ import IAICanvasRedoButton from './IAICanvasRedoButton'; import IAICanvasSettingsButtonPopover from './IAICanvasSettingsButtonPopover'; import IAICanvasToolChooserOptions from './IAICanvasToolChooserOptions'; import IAICanvasUndoButton from './IAICanvasUndoButton'; +import { memo } from 'react'; export const selector = createSelector( [systemSelector, canvasSelector, isStagingSelector], @@ -166,7 +166,9 @@ const IAICanvasToolbar = () => { const handleResetCanvasView = (shouldScaleTo1 = false) => { const canvasBaseLayer = getCanvasBaseLayer(); - if (!canvasBaseLayer) return; + if (!canvasBaseLayer) { + return; + } const clientRect = canvasBaseLayer.getClientRect({ skipTransform: true, }); @@ -180,7 +182,6 @@ const IAICanvasToolbar = () => { const handleResetCanvas = () => { dispatch(resetCanvas()); - dispatch(resizeAndScaleCanvas()); }; const handleMergeVisible = () => { @@ -309,4 +310,4 @@ const IAICanvasToolbar = () => { ); }; -export default IAICanvasToolbar; +export default memo(IAICanvasToolbar); diff --git a/invokeai/frontend/web/src/features/canvas/hooks/useCanvasDragMove.ts b/invokeai/frontend/web/src/features/canvas/hooks/useCanvasDragMove.ts index 6861c25842..81e9c0b855 100644 --- a/invokeai/frontend/web/src/features/canvas/hooks/useCanvasDragMove.ts +++ b/invokeai/frontend/web/src/features/canvas/hooks/useCanvasDragMove.ts @@ -32,13 +32,17 @@ const useCanvasDrag = () => { return { handleDragStart: useCallback(() => { - if (!((tool === 'move' || isStaging) && !isMovingBoundingBox)) return; + if (!((tool === 'move' || isStaging) && !isMovingBoundingBox)) { + return; + } dispatch(setIsMovingStage(true)); }, [dispatch, isMovingBoundingBox, isStaging, tool]), handleDragMove: useCallback( (e: KonvaEventObject) => { - if (!((tool === 'move' || isStaging) && !isMovingBoundingBox)) return; + if (!((tool === 'move' || isStaging) && !isMovingBoundingBox)) { + return; + } const newCoordinates = { x: e.target.x(), y: e.target.y() }; @@ -48,7 +52,9 @@ const useCanvasDrag = () => { ), handleDragEnd: useCallback(() => { - if (!((tool === 'move' || isStaging) && !isMovingBoundingBox)) return; + if (!((tool === 'move' || isStaging) && !isMovingBoundingBox)) { + return; + } dispatch(setIsMovingStage(false)); }, [dispatch, isMovingBoundingBox, isStaging, tool]), }; diff --git a/invokeai/frontend/web/src/features/canvas/hooks/useCanvasHotkeys.ts b/invokeai/frontend/web/src/features/canvas/hooks/useCanvasHotkeys.ts index 6f4669a42a..1641360e5e 100644 --- a/invokeai/frontend/web/src/features/canvas/hooks/useCanvasHotkeys.ts +++ b/invokeai/frontend/web/src/features/canvas/hooks/useCanvasHotkeys.ts @@ -134,7 +134,9 @@ const useInpaintingCanvasHotkeys = () => { useHotkeys( ['space'], (e: KeyboardEvent) => { - if (e.repeat) return; + if (e.repeat) { + return; + } canvasStage?.container().focus(); diff --git a/invokeai/frontend/web/src/features/canvas/hooks/useCanvasMouseDown.ts b/invokeai/frontend/web/src/features/canvas/hooks/useCanvasMouseDown.ts index 67bf7a8539..d98a44edd9 100644 --- a/invokeai/frontend/web/src/features/canvas/hooks/useCanvasMouseDown.ts +++ b/invokeai/frontend/web/src/features/canvas/hooks/useCanvasMouseDown.ts @@ -38,7 +38,9 @@ const useCanvasMouseDown = (stageRef: MutableRefObject) => { return useCallback( (e: KonvaEventObject) => { - if (!stageRef.current) return; + if (!stageRef.current) { + return; + } stageRef.current.container().focus(); @@ -54,7 +56,9 @@ const useCanvasMouseDown = (stageRef: MutableRefObject) => { const scaledCursorPosition = getScaledCursorPosition(stageRef.current); - if (!scaledCursorPosition) return; + if (!scaledCursorPosition) { + return; + } e.evt.preventDefault(); diff --git a/invokeai/frontend/web/src/features/canvas/hooks/useCanvasMouseMove.ts b/invokeai/frontend/web/src/features/canvas/hooks/useCanvasMouseMove.ts index abeab825e4..088356006e 100644 --- a/invokeai/frontend/web/src/features/canvas/hooks/useCanvasMouseMove.ts +++ b/invokeai/frontend/web/src/features/canvas/hooks/useCanvasMouseMove.ts @@ -41,11 +41,15 @@ const useCanvasMouseMove = ( const { updateColorUnderCursor } = useColorPicker(); return useCallback(() => { - if (!stageRef.current) return; + if (!stageRef.current) { + return; + } const scaledCursorPosition = getScaledCursorPosition(stageRef.current); - if (!scaledCursorPosition) return; + if (!scaledCursorPosition) { + return; + } dispatch(setCursorPosition(scaledCursorPosition)); @@ -56,7 +60,9 @@ const useCanvasMouseMove = ( return; } - if (!isDrawing || tool === 'move' || isStaging) return; + if (!isDrawing || tool === 'move' || isStaging) { + return; + } didMouseMoveRef.current = true; dispatch( diff --git a/invokeai/frontend/web/src/features/canvas/hooks/useCanvasMouseUp.ts b/invokeai/frontend/web/src/features/canvas/hooks/useCanvasMouseUp.ts index 8e70543c6f..d99d63c223 100644 --- a/invokeai/frontend/web/src/features/canvas/hooks/useCanvasMouseUp.ts +++ b/invokeai/frontend/web/src/features/canvas/hooks/useCanvasMouseUp.ts @@ -47,7 +47,9 @@ const useCanvasMouseUp = ( if (!didMouseMoveRef.current && isDrawing && stageRef.current) { const scaledCursorPosition = getScaledCursorPosition(stageRef.current); - if (!scaledCursorPosition) return; + if (!scaledCursorPosition) { + return; + } /** * Extend the current line. diff --git a/invokeai/frontend/web/src/features/canvas/hooks/useCanvasZoom.ts b/invokeai/frontend/web/src/features/canvas/hooks/useCanvasZoom.ts index 3d6a1d7804..f58211ca2c 100644 --- a/invokeai/frontend/web/src/features/canvas/hooks/useCanvasZoom.ts +++ b/invokeai/frontend/web/src/features/canvas/hooks/useCanvasZoom.ts @@ -35,13 +35,17 @@ const useCanvasWheel = (stageRef: MutableRefObject) => { return useCallback( (e: KonvaEventObject) => { // stop default scrolling - if (!stageRef.current || isMoveStageKeyHeld) return; + if (!stageRef.current || isMoveStageKeyHeld) { + return; + } e.evt.preventDefault(); const cursorPos = stageRef.current.getPointerPosition(); - if (!cursorPos) return; + if (!cursorPos) { + return; + } const mousePointTo = { x: (cursorPos.x - stageRef.current.x()) / stageScale, diff --git a/invokeai/frontend/web/src/features/canvas/hooks/useColorUnderCursor.ts b/invokeai/frontend/web/src/features/canvas/hooks/useColorUnderCursor.ts index 64289a1fd3..0ade036987 100644 --- a/invokeai/frontend/web/src/features/canvas/hooks/useColorUnderCursor.ts +++ b/invokeai/frontend/web/src/features/canvas/hooks/useColorUnderCursor.ts @@ -16,11 +16,15 @@ const useColorPicker = () => { return { updateColorUnderCursor: () => { - if (!stage || !canvasBaseLayer) return; + if (!stage || !canvasBaseLayer) { + return; + } const position = stage.getPointerPosition(); - if (!position) return; + if (!position) { + return; + } const pixelRatio = Konva.pixelRatio; diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasPersistDenylist.ts b/invokeai/frontend/web/src/features/canvas/store/canvasPersistDenylist.ts index dc0df55ad0..1990f28516 100644 --- a/invokeai/frontend/web/src/features/canvas/store/canvasPersistDenylist.ts +++ b/invokeai/frontend/web/src/features/canvas/store/canvasPersistDenylist.ts @@ -3,8 +3,4 @@ import { CanvasState } from './canvasTypes'; /** * Canvas slice persist denylist */ -export const canvasPersistDenylist: (keyof CanvasState)[] = [ - 'cursorPosition', - 'isCanvasInitialized', - 'doesCanvasNeedScaling', -]; +export const canvasPersistDenylist: (keyof CanvasState)[] = ['cursorPosition']; diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts index 11f829221a..ca26a0567f 100644 --- a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts @@ -5,10 +5,6 @@ import { roundToMultiple, } from 'common/util/roundDownToMultiple'; import { setAspectRatio } from 'features/parameters/store/generationSlice'; -import { - setActiveTab, - setShouldUseCanvasBetaLayout, -} from 'features/ui/store/uiSlice'; import { IRect, Vector2d } from 'konva/lib/types'; import { clamp, cloneDeep } from 'lodash-es'; import { RgbaColor } from 'react-colorful'; @@ -50,12 +46,9 @@ export const initialCanvasState: CanvasState = { boundingBoxScaleMethod: 'none', brushColor: { r: 90, g: 90, b: 255, a: 1 }, brushSize: 50, - canvasContainerDimensions: { width: 0, height: 0 }, colorPickerColor: { r: 90, g: 90, b: 255, a: 1 }, cursorPosition: null, - doesCanvasNeedScaling: false, futureLayerStates: [], - isCanvasInitialized: false, isDrawing: false, isMaskEnabled: true, isMouseOverBoundingBox: false, @@ -208,7 +201,6 @@ export const canvasSlice = createSlice({ }; state.futureLayerStates = []; - state.isCanvasInitialized = false; const newScale = calculateScale( stageDimensions.width, stageDimensions.height, @@ -228,7 +220,6 @@ export const canvasSlice = createSlice({ ); state.stageScale = newScale; state.stageCoordinates = newCoordinates; - state.doesCanvasNeedScaling = true; }, setBoundingBoxDimensions: (state, action: PayloadAction) => { const newDimensions = roundDimensionsTo64(action.payload); @@ -258,9 +249,6 @@ export const canvasSlice = createSlice({ setBoundingBoxPreviewFill: (state, action: PayloadAction) => { state.boundingBoxPreviewFill = action.payload; }, - setDoesCanvasNeedScaling: (state, action: PayloadAction) => { - state.doesCanvasNeedScaling = action.payload; - }, setStageScale: (state, action: PayloadAction) => { state.stageScale = action.payload; }, @@ -397,7 +385,9 @@ export const canvasSlice = createSlice({ const { tool, layer, brushColor, brushSize, shouldRestrictStrokesToBox } = state; - if (tool === 'move' || tool === 'colorPicker') return; + if (tool === 'move' || tool === 'colorPicker') { + return; + } const newStrokeWidth = brushSize / 2; @@ -434,14 +424,18 @@ export const canvasSlice = createSlice({ addPointToCurrentLine: (state, action: PayloadAction) => { const lastLine = state.layerState.objects.findLast(isCanvasAnyLine); - if (!lastLine) return; + if (!lastLine) { + return; + } lastLine.points.push(...action.payload); }, undo: (state) => { const targetState = state.pastLayerStates.pop(); - if (!targetState) return; + if (!targetState) { + return; + } state.futureLayerStates.unshift(cloneDeep(state.layerState)); @@ -454,7 +448,9 @@ export const canvasSlice = createSlice({ redo: (state) => { const targetState = state.futureLayerStates.shift(); - if (!targetState) return; + if (!targetState) { + return; + } state.pastLayerStates.push(cloneDeep(state.layerState)); @@ -485,97 +481,14 @@ export const canvasSlice = createSlice({ state.layerState = initialLayerState; state.futureLayerStates = []; }, - setCanvasContainerDimensions: ( + canvasResized: ( state, - action: PayloadAction + action: PayloadAction<{ width: number; height: number }> ) => { - state.canvasContainerDimensions = action.payload; - }, - resizeAndScaleCanvas: (state) => { - const { width: containerWidth, height: containerHeight } = - state.canvasContainerDimensions; - - const initialCanvasImage = - state.layerState.objects.find(isCanvasBaseImage); - + const { width, height } = action.payload; const newStageDimensions = { - width: Math.floor(containerWidth), - height: Math.floor(containerHeight), - }; - - if (!initialCanvasImage) { - const newScale = calculateScale( - newStageDimensions.width, - newStageDimensions.height, - 512, - 512, - STAGE_PADDING_PERCENTAGE - ); - - const newCoordinates = calculateCoordinates( - newStageDimensions.width, - newStageDimensions.height, - 0, - 0, - 512, - 512, - newScale - ); - - const newBoundingBoxDimensions = { width: 512, height: 512 }; - - state.stageScale = newScale; - state.stageCoordinates = newCoordinates; - state.stageDimensions = newStageDimensions; - state.boundingBoxCoordinates = { x: 0, y: 0 }; - state.boundingBoxDimensions = newBoundingBoxDimensions; - - if (state.boundingBoxScaleMethod === 'auto') { - const scaledDimensions = getScaledBoundingBoxDimensions( - newBoundingBoxDimensions - ); - state.scaledBoundingBoxDimensions = scaledDimensions; - } - - return; - } - - const { width: imageWidth, height: imageHeight } = initialCanvasImage; - - const padding = 0.95; - - const newScale = calculateScale( - containerWidth, - containerHeight, - imageWidth, - imageHeight, - padding - ); - - const newCoordinates = calculateCoordinates( - newStageDimensions.width, - newStageDimensions.height, - 0, - 0, - imageWidth, - imageHeight, - newScale - ); - - state.minimumStageScale = newScale; - state.stageScale = newScale; - state.stageCoordinates = floorCoordinates(newCoordinates); - state.stageDimensions = newStageDimensions; - - state.isCanvasInitialized = true; - }, - resizeCanvas: (state) => { - const { width: containerWidth, height: containerHeight } = - state.canvasContainerDimensions; - - const newStageDimensions = { - width: Math.floor(containerWidth), - height: Math.floor(containerHeight), + width: Math.floor(width), + height: Math.floor(height), }; state.stageDimensions = newStageDimensions; @@ -868,14 +781,6 @@ export const canvasSlice = createSlice({ state.layerState.stagingArea = initialLayerState.stagingArea; } }); - - builder.addCase(setShouldUseCanvasBetaLayout, (state) => { - state.doesCanvasNeedScaling = true; - }); - - builder.addCase(setActiveTab, (state) => { - state.doesCanvasNeedScaling = true; - }); builder.addCase(setAspectRatio, (state, action) => { const ratio = action.payload; if (ratio) { @@ -907,8 +812,6 @@ export const { resetCanvas, resetCanvasInteractionState, resetCanvasView, - resizeAndScaleCanvas, - resizeCanvas, setBoundingBoxCoordinates, setBoundingBoxDimensions, setBoundingBoxPreviewFill, @@ -916,10 +819,8 @@ export const { flipBoundingBoxAxes, setBrushColor, setBrushSize, - setCanvasContainerDimensions, setColorPickerColor, setCursorPosition, - setDoesCanvasNeedScaling, setInitialCanvasImage, setIsDrawing, setIsMaskEnabled, @@ -958,6 +859,7 @@ export const { stagingAreaInitialized, canvasSessionIdChanged, setShouldAntialias, + canvasResized, } = canvasSlice.actions; export default canvasSlice.reducer; diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts b/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts index f2ba90b050..1b4eca329d 100644 --- a/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts +++ b/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts @@ -126,12 +126,9 @@ export interface CanvasState { boundingBoxScaleMethod: BoundingBoxScale; brushColor: RgbaColor; brushSize: number; - canvasContainerDimensions: Dimensions; colorPickerColor: RgbaColor; cursorPosition: Vector2d | null; - doesCanvasNeedScaling: boolean; futureLayerStates: CanvasLayerState[]; - isCanvasInitialized: boolean; isDrawing: boolean; isMaskEnabled: boolean; isMouseOverBoundingBox: boolean; diff --git a/invokeai/frontend/web/src/features/canvas/store/thunks/requestCanvasScale.ts b/invokeai/frontend/web/src/features/canvas/store/thunks/requestCanvasScale.ts deleted file mode 100644 index f16c92651a..0000000000 --- a/invokeai/frontend/web/src/features/canvas/store/thunks/requestCanvasScale.ts +++ /dev/null @@ -1,16 +0,0 @@ -import { AppDispatch, AppGetState } from 'app/store/store'; -import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; -import { debounce } from 'lodash-es'; -import { setDoesCanvasNeedScaling } from '../canvasSlice'; - -const debouncedCanvasScale = debounce((dispatch: AppDispatch) => { - dispatch(setDoesCanvasNeedScaling(true)); -}, 300); - -export const requestCanvasRescale = - () => (dispatch: AppDispatch, getState: AppGetState) => { - const activeTabName = activeTabNameSelector(getState()); - if (activeTabName === 'unifiedCanvas') { - debouncedCanvasScale(dispatch); - } - }; diff --git a/invokeai/frontend/web/src/features/canvas/util/getScaledCursorPosition.ts b/invokeai/frontend/web/src/features/canvas/util/getScaledCursorPosition.ts index 03a4d749bf..4cfd7dc8f1 100644 --- a/invokeai/frontend/web/src/features/canvas/util/getScaledCursorPosition.ts +++ b/invokeai/frontend/web/src/features/canvas/util/getScaledCursorPosition.ts @@ -5,7 +5,9 @@ const getScaledCursorPosition = (stage: Stage) => { const stageTransform = stage.getAbsoluteTransform().copy(); - if (!pointerPosition || !stageTransform) return; + if (!pointerPosition || !stageTransform) { + return; + } const scaledCursorPosition = stageTransform.invert().point(pointerPosition); diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx index 3252207edc..de9995c577 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx @@ -80,19 +80,19 @@ const ControlNet = (props: ControlNetProps) => { sx={{ flexDir: 'column', gap: 3, - p: 3, + p: 2, borderRadius: 'base', position: 'relative', - bg: 'base.200', + bg: 'base.250', _dark: { - bg: 'base.850', + bg: 'base.750', }, }} > @@ -194,7 +194,7 @@ const ControlNet = (props: ControlNetProps) => { aspectRatio: '1/1', }} > - + )} @@ -207,7 +207,7 @@ const ControlNet = (props: ControlNetProps) => { {isExpanded && ( <> - + diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx index 0683282811..3b92d9d0c6 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx @@ -1,14 +1,14 @@ -import { Box, Flex, Spinner, SystemStyleObject } from '@chakra-ui/react'; +import { Box, Flex, Spinner } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; import { skipToken } from '@reduxjs/toolkit/dist/query'; -import { - TypesafeDraggableData, - TypesafeDroppableData, -} from 'features/dnd/types'; import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIDndImage from 'common/components/IAIDndImage'; +import { + TypesafeDraggableData, + TypesafeDroppableData, +} from 'features/dnd/types'; import { memo, useCallback, useMemo, useState } from 'react'; import { FaUndo } from 'react-icons/fa'; import { useGetImageDTOQuery } from 'services/api/endpoints/images'; @@ -21,7 +21,7 @@ import { type Props = { controlNet: ControlNetConfig; - height: SystemStyleObject['h']; + isSmall?: boolean; }; const selector = createSelector( @@ -36,15 +36,14 @@ const selector = createSelector( defaultSelectorOptions ); -const ControlNetImagePreview = (props: Props) => { - const { height } = props; +const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => { const { controlImage: controlImageName, processedControlImage: processedControlImageName, processorType, isEnabled, controlNetId, - } = props.controlNet; + } = controlNet; const dispatch = useAppDispatch(); @@ -109,7 +108,7 @@ const ControlNetImagePreview = (props: Props) => { sx={{ position: 'relative', w: 'full', - h: height, + h: isSmall ? 28 : 366, // magic no touch alignItems: 'center', justifyContent: 'center', pointerEvents: isEnabled ? 'auto' : 'none', diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetFeatureToggle.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetFeatureToggle.tsx index 8eed90ce16..97a54dc7d1 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetFeatureToggle.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetFeatureToggle.tsx @@ -4,7 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAISwitch from 'common/components/IAISwitch'; import { isControlNetEnabledToggled } from 'features/controlNet/store/controlNetSlice'; -import { useCallback } from 'react'; +import { memo, useCallback } from 'react'; const selector = createSelector( stateSelector, @@ -36,4 +36,4 @@ const ParamControlNetFeatureToggle = () => { ); }; -export default ParamControlNetFeatureToggle; +export default memo(ParamControlNetFeatureToggle); diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetWeight.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetWeight.tsx index c08283e1f9..6725c47bb8 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetWeight.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetWeight.tsx @@ -23,7 +23,7 @@ const ParamControlNetWeight = (props: ParamControlNetWeightProps) => { return ( { ); }; -export default ParamDynamicPromptsCollapse; +export default memo(ParamDynamicPromptsCollapse); diff --git a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCombinatorial.tsx b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCombinatorial.tsx index 809ec0df10..c028a5d55c 100644 --- a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCombinatorial.tsx +++ b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCombinatorial.tsx @@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAISwitch from 'common/components/IAISwitch'; -import { useCallback } from 'react'; +import { memo, useCallback } from 'react'; import { combinatorialToggled } from '../store/dynamicPromptsSlice'; const selector = createSelector( @@ -34,4 +34,4 @@ const ParamDynamicPromptsCombinatorial = () => { ); }; -export default ParamDynamicPromptsCombinatorial; +export default memo(ParamDynamicPromptsCombinatorial); diff --git a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsEnabled.tsx b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsEnabled.tsx index f92fa410f2..1b31147937 100644 --- a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsEnabled.tsx +++ b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsEnabled.tsx @@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAISwitch from 'common/components/IAISwitch'; -import { useCallback } from 'react'; +import { memo, useCallback } from 'react'; import { isEnabledToggled } from '../store/dynamicPromptsSlice'; const selector = createSelector( @@ -33,4 +33,4 @@ const ParamDynamicPromptsToggle = () => { ); }; -export default ParamDynamicPromptsToggle; +export default memo(ParamDynamicPromptsToggle); diff --git a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx index 5bee317d22..f374f1cb15 100644 --- a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx +++ b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx @@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAISlider from 'common/components/IAISlider'; -import { useCallback } from 'react'; +import { memo, useCallback } from 'react'; import { maxPromptsChanged, maxPromptsReset, @@ -60,4 +60,4 @@ const ParamDynamicPromptsMaxPrompts = () => { ); }; -export default ParamDynamicPromptsMaxPrompts; +export default memo(ParamDynamicPromptsMaxPrompts); diff --git a/invokeai/frontend/web/src/features/embedding/components/ParamEmbeddingPopover.tsx b/invokeai/frontend/web/src/features/embedding/components/ParamEmbeddingPopover.tsx index 4eb9a67de2..93daaf946f 100644 --- a/invokeai/frontend/web/src/features/embedding/components/ParamEmbeddingPopover.tsx +++ b/invokeai/frontend/web/src/features/embedding/components/ParamEmbeddingPopover.tsx @@ -13,7 +13,7 @@ import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSe import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { forEach } from 'lodash-es'; -import { PropsWithChildren, useCallback, useMemo, useRef } from 'react'; +import { PropsWithChildren, memo, useCallback, useMemo, useRef } from 'react'; import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models'; import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants'; @@ -118,7 +118,7 @@ const ParamEmbeddingPopover = (props: Props) => { { ); }; -export default ParamEmbeddingPopover; +export default memo(ParamEmbeddingPopover); diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/AutoAddIcon.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/AutoAddIcon.tsx index ffdde04ef5..4e748d61e8 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/AutoAddIcon.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/AutoAddIcon.tsx @@ -1,4 +1,5 @@ import { Badge, Flex } from '@chakra-ui/react'; +import { memo } from 'react'; const AutoAddIcon = () => { return ( @@ -20,4 +21,4 @@ const AutoAddIcon = () => { ); }; -export default AutoAddIcon; +export default memo(AutoAddIcon); diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardAutoAddSelect.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardAutoAddSelect.tsx index 96d17b548e..be19589f9b 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardAutoAddSelect.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardAutoAddSelect.tsx @@ -6,7 +6,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; import { autoAddBoardIdChanged } from 'features/gallery/store/gallerySlice'; -import { useCallback, useRef } from 'react'; +import { memo, useCallback, useRef } from 'react'; import { useListAllBoardsQuery } from 'services/api/endpoints/boards'; const selector = createSelector( @@ -66,7 +66,7 @@ const BoardAutoAddSelect = () => { label="Auto-Add Board" inputRef={inputRef} autoFocus - placeholder={'Select a Board'} + placeholder="Select a Board" value={autoAddBoardId} data={boards} nothingFound="No matching Boards" @@ -81,4 +81,4 @@ const BoardAutoAddSelect = () => { ); }; -export default BoardAutoAddSelect; +export default memo(BoardAutoAddSelect); diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx index 0667c05435..6a012030e8 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx @@ -2,8 +2,12 @@ import { MenuGroup, MenuItem, MenuList } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu'; +import { + IAIContextMenu, + IAIContextMenuProps, +} from 'common/components/IAIContextMenu'; import { autoAddBoardIdChanged } from 'features/gallery/store/gallerySlice'; +import { BoardId } from 'features/gallery/store/types'; import { MouseEvent, memo, useCallback, useMemo } from 'react'; import { FaPlus } from 'react-icons/fa'; import { useBoardName } from 'services/api/hooks/useBoardName'; @@ -11,80 +15,80 @@ import { BoardDTO } from 'services/api/types'; import { menuListMotionProps } from 'theme/components/menu'; import GalleryBoardContextMenuItems from './GalleryBoardContextMenuItems'; import NoBoardContextMenuItems from './NoBoardContextMenuItems'; -import { BoardId } from 'features/gallery/store/types'; type Props = { board?: BoardDTO; board_id: BoardId; - children: ContextMenuProps['children']; + children: IAIContextMenuProps['children']; setBoardToDelete?: (board?: BoardDTO) => void; }; -const BoardContextMenu = memo( - ({ board, board_id, setBoardToDelete, children }: Props) => { - const dispatch = useAppDispatch(); +const BoardContextMenu = ({ + board, + board_id, + setBoardToDelete, + children, +}: Props) => { + const dispatch = useAppDispatch(); - const selector = useMemo( - () => - createSelector(stateSelector, ({ gallery, system }) => { - const isAutoAdd = gallery.autoAddBoardId === board_id; - const isProcessing = system.isProcessing; - const autoAssignBoardOnClick = gallery.autoAssignBoardOnClick; - return { isAutoAdd, isProcessing, autoAssignBoardOnClick }; - }), - [board_id] - ); + const selector = useMemo( + () => + createSelector(stateSelector, ({ gallery, system }) => { + const isAutoAdd = gallery.autoAddBoardId === board_id; + const isProcessing = system.isProcessing; + const autoAssignBoardOnClick = gallery.autoAssignBoardOnClick; + return { isAutoAdd, isProcessing, autoAssignBoardOnClick }; + }), + [board_id] + ); - const { isAutoAdd, isProcessing, autoAssignBoardOnClick } = - useAppSelector(selector); - const boardName = useBoardName(board_id); + const { isAutoAdd, isProcessing, autoAssignBoardOnClick } = + useAppSelector(selector); + const boardName = useBoardName(board_id); - const handleSetAutoAdd = useCallback(() => { - dispatch(autoAddBoardIdChanged(board_id)); - }, [board_id, dispatch]); + const handleSetAutoAdd = useCallback(() => { + dispatch(autoAddBoardIdChanged(board_id)); + }, [board_id, dispatch]); - const skipEvent = useCallback((e: MouseEvent) => { - e.preventDefault(); - }, []); + const skipEvent = useCallback((e: MouseEvent) => { + e.preventDefault(); + }, []); - return ( - - menuProps={{ size: 'sm', isLazy: true }} - menuButtonProps={{ - bg: 'transparent', - _hover: { bg: 'transparent' }, - }} - renderMenu={() => ( - - - } - isDisabled={isAutoAdd || isProcessing || autoAssignBoardOnClick} - onClick={handleSetAutoAdd} - > - Auto-add to this Board - - {!board && } - {board && ( - - )} - - - )} - > - {children} - - ); - } -); + return ( + + menuProps={{ size: 'sm', isLazy: true }} + menuButtonProps={{ + bg: 'transparent', + _hover: { bg: 'transparent' }, + }} + renderMenu={() => ( + + + } + isDisabled={isAutoAdd || isProcessing || autoAssignBoardOnClick} + onClick={handleSetAutoAdd} + > + Auto-add to this Board + + {!board && } + {board && ( + + )} + + + )} + > + {children} + + ); +}; -BoardContextMenu.displayName = 'HoverableBoard'; - -export default BoardContextMenu; +export default memo(BoardContextMenu); diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/AddBoardButton.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/AddBoardButton.tsx index 7a07680878..ebd08e94d5 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/AddBoardButton.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/AddBoardButton.tsx @@ -1,5 +1,5 @@ import IAIIconButton from 'common/components/IAIIconButton'; -import { useCallback } from 'react'; +import { memo, useCallback } from 'react'; import { FaPlus } from 'react-icons/fa'; import { useCreateBoardMutation } from 'services/api/endpoints/boards'; @@ -24,4 +24,4 @@ const AddBoardButton = () => { ); }; -export default AddBoardButton; +export default memo(AddBoardButton); diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardsList.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardsList.tsx index cb3474f6bd..4bbd9533fa 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardsList.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardsList.tsx @@ -41,7 +41,7 @@ const BoardsList = (props: Props) => { <> void; } -const GalleryBoard = memo( - ({ board, isSelected, setBoardToDelete }: GalleryBoardProps) => { - const dispatch = useAppDispatch(); - const selector = useMemo( - () => - createSelector( - stateSelector, - ({ gallery, system }) => { - const isSelectedForAutoAdd = - board.board_id === gallery.autoAddBoardId; - const autoAssignBoardOnClick = gallery.autoAssignBoardOnClick; - const isProcessing = system.isProcessing; +const GalleryBoard = ({ + board, + isSelected, + setBoardToDelete, +}: GalleryBoardProps) => { + const dispatch = useAppDispatch(); + const selector = useMemo( + () => + createSelector( + stateSelector, + ({ gallery, system }) => { + const isSelectedForAutoAdd = + board.board_id === gallery.autoAddBoardId; + const autoAssignBoardOnClick = gallery.autoAssignBoardOnClick; + const isProcessing = system.isProcessing; - return { - isSelectedForAutoAdd, - autoAssignBoardOnClick, - isProcessing, - }; - }, - defaultSelectorOptions - ), - [board.board_id] - ); + return { + isSelectedForAutoAdd, + autoAssignBoardOnClick, + isProcessing, + }; + }, + defaultSelectorOptions + ), + [board.board_id] + ); - const { isSelectedForAutoAdd, autoAssignBoardOnClick, isProcessing } = - useAppSelector(selector); - const [isHovered, setIsHovered] = useState(false); - const handleMouseOver = useCallback(() => { - setIsHovered(true); - }, []); - const handleMouseOut = useCallback(() => { - setIsHovered(false); - }, []); + const { isSelectedForAutoAdd, autoAssignBoardOnClick, isProcessing } = + useAppSelector(selector); + const [isHovered, setIsHovered] = useState(false); + const handleMouseOver = useCallback(() => { + setIsHovered(true); + }, []); + const handleMouseOut = useCallback(() => { + setIsHovered(false); + }, []); - const { data: imagesTotal } = useGetBoardImagesTotalQuery(board.board_id); - const { data: assetsTotal } = useGetBoardAssetsTotalQuery(board.board_id); - const tooltip = useMemo(() => { - if (!imagesTotal || !assetsTotal) { - return undefined; + const { data: imagesTotal } = useGetBoardImagesTotalQuery(board.board_id); + const { data: assetsTotal } = useGetBoardAssetsTotalQuery(board.board_id); + const tooltip = useMemo(() => { + if (!imagesTotal || !assetsTotal) { + return undefined; + } + return `${imagesTotal} image${ + imagesTotal > 1 ? 's' : '' + }, ${assetsTotal} asset${assetsTotal > 1 ? 's' : ''}`; + }, [assetsTotal, imagesTotal]); + + const { currentData: coverImage } = useGetImageDTOQuery( + board.cover_image_name ?? skipToken + ); + + const { board_name, board_id } = board; + const [localBoardName, setLocalBoardName] = useState(board_name); + + const handleSelectBoard = useCallback(() => { + dispatch(boardIdSelected(board_id)); + if (autoAssignBoardOnClick && !isProcessing) { + dispatch(autoAddBoardIdChanged(board_id)); + } + }, [board_id, autoAssignBoardOnClick, isProcessing, dispatch]); + + const [updateBoard, { isLoading: isUpdateBoardLoading }] = + useUpdateBoardMutation(); + + const droppableData: AddToBoardDropData = useMemo( + () => ({ + id: board_id, + actionType: 'ADD_TO_BOARD', + context: { boardId: board_id }, + }), + [board_id] + ); + + const handleSubmit = useCallback( + async (newBoardName: string) => { + // empty strings are not allowed + if (!newBoardName.trim()) { + setLocalBoardName(board_name); + return; } - return `${imagesTotal} image${ - imagesTotal > 1 ? 's' : '' - }, ${assetsTotal} asset${assetsTotal > 1 ? 's' : ''}`; - }, [assetsTotal, imagesTotal]); - const { currentData: coverImage } = useGetImageDTOQuery( - board.cover_image_name ?? skipToken - ); - - const { board_name, board_id } = board; - const [localBoardName, setLocalBoardName] = useState(board_name); - - const handleSelectBoard = useCallback(() => { - dispatch(boardIdSelected(board_id)); - if (autoAssignBoardOnClick && !isProcessing) { - dispatch(autoAddBoardIdChanged(board_id)); + // don't updated the board name if it hasn't changed + if (newBoardName === board_name) { + return; } - }, [board_id, autoAssignBoardOnClick, isProcessing, dispatch]); - const [updateBoard, { isLoading: isUpdateBoardLoading }] = - useUpdateBoardMutation(); + try { + const { board_name } = await updateBoard({ + board_id, + changes: { board_name: newBoardName }, + }).unwrap(); - const droppableData: AddToBoardDropData = useMemo( - () => ({ - id: board_id, - actionType: 'ADD_TO_BOARD', - context: { boardId: board_id }, - }), - [board_id] - ); + // update local state + setLocalBoardName(board_name); + } catch { + // revert on error + setLocalBoardName(board_name); + } + }, + [board_id, board_name, updateBoard] + ); - const handleSubmit = useCallback( - async (newBoardName: string) => { - // empty strings are not allowed - if (!newBoardName.trim()) { - setLocalBoardName(board_name); - return; - } + const handleChange = useCallback((newBoardName: string) => { + setLocalBoardName(newBoardName); + }, []); - // don't updated the board name if it hasn't changed - if (newBoardName === board_name) { - return; - } - - try { - const { board_name } = await updateBoard({ - board_id, - changes: { board_name: newBoardName }, - }).unwrap(); - - // update local state - setLocalBoardName(board_name); - } catch { - // revert on error - setLocalBoardName(board_name); - } - }, - [board_id, board_name, updateBoard] - ); - - const handleChange = useCallback((newBoardName: string) => { - setLocalBoardName(newBoardName); - }, []); - - return ( - + - - - {(ref) => ( - - - {coverImage?.thumbnail_url ? ( - ( + + + {coverImage?.thumbnail_url ? ( + + ) : ( + + - ) : ( - - - - )} - {/* + )} + {/* */} - {isSelectedForAutoAdd && } - - } + + + - - + - - - - - Move} - /> + // get rid of the edit border + boxShadow: 'none', + }, + }} + /> + - - )} - - - - ); - } -); -GalleryBoard.displayName = 'HoverableBoard'; + Move} + /> + + + )} + + + + ); +}; -export default GalleryBoard; +export default memo(GalleryBoard); diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GenericBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GenericBoard.tsx index 1698a81ac0..7a95e7fcd9 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GenericBoard.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GenericBoard.tsx @@ -3,7 +3,7 @@ import IAIDroppable from 'common/components/IAIDroppable'; import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import { TypesafeDroppableData } from 'features/dnd/types'; import { BoardId } from 'features/gallery/store/types'; -import { ReactNode } from 'react'; +import { ReactNode, memo } from 'react'; import BoardContextMenu from '../BoardContextMenu'; type GenericBoardProps = { @@ -105,4 +105,4 @@ const GenericBoard = (props: GenericBoardProps) => { ); }; -export default GenericBoard; +export default memo(GenericBoard); diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx index fec280db0f..da51a5fe39 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx @@ -156,4 +156,4 @@ const NoBoardBoard = memo(({ isSelected }: Props) => { NoBoardBoard.displayName = 'HoverableBoard'; -export default NoBoardBoard; +export default memo(NoBoardBoard); diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageButtons.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageButtons.tsx index d62027769b..0212376507 100644 --- a/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageButtons.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageButtons.tsx @@ -26,7 +26,7 @@ import { setShouldShowImageDetails, setShouldShowProgressInViewer, } from 'features/ui/store/uiSlice'; -import { useCallback } from 'react'; +import { memo, useCallback } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; import { useTranslation } from 'react-i18next'; import { @@ -323,4 +323,4 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { ); }; -export default CurrentImageButtons; +export default memo(CurrentImageButtons); diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageDisplay.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageDisplay.tsx index 1d8863f4d8..1c342d093e 100644 --- a/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageDisplay.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageDisplay.tsx @@ -2,6 +2,7 @@ import { Flex } from '@chakra-ui/react'; import CurrentImageButtons from './CurrentImageButtons'; import CurrentImagePreview from './CurrentImagePreview'; +import { memo } from 'react'; const CurrentImageDisplay = () => { return ( @@ -22,4 +23,4 @@ const CurrentImageDisplay = () => { ); }; -export default CurrentImageDisplay; +export default memo(CurrentImageDisplay); diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageHidden.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageHidden.tsx index 062cdd7c00..af2a7c5f98 100644 --- a/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageHidden.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageHidden.tsx @@ -1,4 +1,5 @@ import { Flex } from '@chakra-ui/react'; +import { memo } from 'react'; import { FaEyeSlash } from 'react-icons/fa'; const CurrentImageHidden = () => { @@ -18,4 +19,4 @@ const CurrentImageHidden = () => { ); }; -export default CurrentImageHidden; +export default memo(CurrentImageHidden); diff --git a/invokeai/frontend/web/src/features/gallery/components/GalleryPanel.tsx b/invokeai/frontend/web/src/features/gallery/components/GalleryPanel.tsx deleted file mode 100644 index 1bbec03f3e..0000000000 --- a/invokeai/frontend/web/src/features/gallery/components/GalleryPanel.tsx +++ /dev/null @@ -1,119 +0,0 @@ -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { gallerySelector } from 'features/gallery/store/gallerySelectors'; -import { setGalleryImageMinimumWidth } from 'features/gallery/store/gallerySlice'; - -import { clamp, isEqual } from 'lodash-es'; -import { useHotkeys } from 'react-hotkeys-hook'; - -import { createSelector } from '@reduxjs/toolkit'; -import { isStagingSelector } from 'features/canvas/store/canvasSelectors'; -import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; -import ResizableDrawer from 'features/ui/components/common/ResizableDrawer/ResizableDrawer'; -import { - activeTabNameSelector, - uiSelector, -} from 'features/ui/store/uiSelectors'; -import { setShouldShowGallery } from 'features/ui/store/uiSlice'; -import { memo } from 'react'; -import ImageGalleryContent from './ImageGalleryContent'; - -const selector = createSelector( - [activeTabNameSelector, uiSelector, gallerySelector, isStagingSelector], - (activeTabName, ui, gallery, isStaging) => { - const { shouldPinGallery, shouldShowGallery } = ui; - const { galleryImageMinimumWidth } = gallery; - - return { - activeTabName, - isStaging, - shouldPinGallery, - shouldShowGallery, - galleryImageMinimumWidth, - isResizable: activeTabName !== 'unifiedCanvas', - }; - }, - { - memoizeOptions: { - resultEqualityCheck: isEqual, - }, - } -); - -const GalleryDrawer = () => { - const dispatch = useAppDispatch(); - const { - shouldPinGallery, - shouldShowGallery, - galleryImageMinimumWidth, - // activeTabName, - // isStaging, - // isResizable, - } = useAppSelector(selector); - - const handleCloseGallery = () => { - dispatch(setShouldShowGallery(false)); - shouldPinGallery && dispatch(requestCanvasRescale()); - }; - - useHotkeys( - 'esc', - () => { - dispatch(setShouldShowGallery(false)); - }, - { - enabled: () => !shouldPinGallery, - preventDefault: true, - }, - [shouldPinGallery] - ); - - const IMAGE_SIZE_STEP = 32; - - useHotkeys( - 'shift+up', - () => { - if (galleryImageMinimumWidth < 256) { - const newMinWidth = clamp( - galleryImageMinimumWidth + IMAGE_SIZE_STEP, - 32, - 256 - ); - dispatch(setGalleryImageMinimumWidth(newMinWidth)); - } - }, - [galleryImageMinimumWidth] - ); - - useHotkeys( - 'shift+down', - () => { - if (galleryImageMinimumWidth > 32) { - const newMinWidth = clamp( - galleryImageMinimumWidth - IMAGE_SIZE_STEP, - 32, - 256 - ); - dispatch(setGalleryImageMinimumWidth(newMinWidth)); - } - }, - [galleryImageMinimumWidth] - ); - - if (shouldPinGallery) { - return null; - } - - return ( - - - - ); -}; - -export default memo(GalleryDrawer); diff --git a/invokeai/frontend/web/src/features/gallery/components/GalleryPinButton.tsx b/invokeai/frontend/web/src/features/gallery/components/GalleryPinButton.tsx deleted file mode 100644 index 916dec69a2..0000000000 --- a/invokeai/frontend/web/src/features/gallery/components/GalleryPinButton.tsx +++ /dev/null @@ -1,44 +0,0 @@ -import { createSelector } from '@reduxjs/toolkit'; -import { stateSelector } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import IAIIconButton from 'common/components/IAIIconButton'; -import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; -import { togglePinGalleryPanel } from 'features/ui/store/uiSlice'; -import { useTranslation } from 'react-i18next'; -import { BsPinAngle, BsPinAngleFill } from 'react-icons/bs'; - -const selector = createSelector( - [stateSelector], - (state) => { - const { shouldPinGallery } = state.ui; - - return { - shouldPinGallery, - }; - }, - defaultSelectorOptions -); - -const GalleryPinButton = () => { - const dispatch = useAppDispatch(); - const { t } = useTranslation(); - - const { shouldPinGallery } = useAppSelector(selector); - - const handleSetShouldPinGallery = () => { - dispatch(togglePinGalleryPanel()); - dispatch(requestCanvasRescale()); - }; - return ( - : } - /> - ); -}; - -export default GalleryPinButton; diff --git a/invokeai/frontend/web/src/features/gallery/components/GallerySettingsPopover.tsx b/invokeai/frontend/web/src/features/gallery/components/GallerySettingsPopover.tsx index 23cfdcc5fd..2eab78d118 100644 --- a/invokeai/frontend/web/src/features/gallery/components/GallerySettingsPopover.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/GallerySettingsPopover.tsx @@ -12,7 +12,7 @@ import { setGalleryImageMinimumWidth, shouldAutoSwitchChanged, } from 'features/gallery/store/gallerySlice'; -import { ChangeEvent, useCallback } from 'react'; +import { ChangeEvent, memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { FaWrench } from 'react-icons/fa'; import BoardAutoAddSelect from './Boards/BoardAutoAddSelect'; @@ -101,4 +101,4 @@ const GallerySettingsPopover = () => { ); }; -export default GallerySettingsPopover; +export default memo(GallerySettingsPopover); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/MultipleSelectionMenuItems.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/MultipleSelectionMenuItems.tsx index 0f36273122..bf2b344b4c 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/MultipleSelectionMenuItems.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/MultipleSelectionMenuItems.tsx @@ -5,7 +5,7 @@ import { isModalOpenChanged, } from 'features/changeBoardModal/store/slice'; import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice'; -import { useCallback, useMemo } from 'react'; +import { memo, useCallback, useMemo } from 'react'; import { FaFolder, FaTrash } from 'react-icons/fa'; import { MdStar, MdStarBorder } from 'react-icons/md'; import { @@ -74,4 +74,4 @@ const MultipleSelectionMenuItems = () => { ); }; -export default MultipleSelectionMenuItems; +export default memo(MultipleSelectionMenuItems); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx index ef6e2ccd5c..e57c8b9797 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx @@ -2,10 +2,7 @@ import { MenuItem } from '@chakra-ui/react'; import { skipToken } from '@reduxjs/toolkit/dist/query'; import { useAppToaster } from 'app/components/Toaster'; import { useAppDispatch } from 'app/store/storeHooks'; -import { - resizeAndScaleCanvas, - setInitialCanvasImage, -} from 'features/canvas/store/canvasSlice'; +import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { imagesToChangeSelected, isModalOpenChanged, @@ -29,6 +26,7 @@ import { FaShare, FaTrash, } from 'react-icons/fa'; +import { MdStar, MdStarBorder } from 'react-icons/md'; import { useGetImageMetadataQuery, useStarImagesMutation, @@ -37,7 +35,6 @@ import { import { ImageDTO } from 'services/api/types'; import { useDebounce } from 'use-debounce'; import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions'; -import { MdStar, MdStarBorder } from 'react-icons/md'; type SingleSelectionMenuItemsProps = { imageDTO: ImageDTO; @@ -110,7 +107,6 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => { const handleSendToCanvas = useCallback(() => { dispatch(sentImageToCanvas()); dispatch(setInitialCanvasImage(imageDTO)); - dispatch(resizeAndScaleCanvas()); dispatch(setActiveTab('unifiedCanvas')); toaster({ @@ -136,11 +132,15 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => { }, [copyImageToClipboard, imageDTO.image_url]); const handleStarImage = useCallback(() => { - if (imageDTO) starImages({ imageDTOs: [imageDTO] }); + if (imageDTO) { + starImages({ imageDTOs: [imageDTO] }); + } }, [starImages, imageDTO]); const handleUnstarImage = useCallback(() => { - if (imageDTO) unstarImages({ imageDTOs: [imageDTO] }); + if (imageDTO) { + unstarImages({ imageDTOs: [imageDTO] }); + } }, [unstarImages, imageDTO]); return ( diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageFallbackSpinner.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageFallbackSpinner.tsx index fd603d3756..95577efc13 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageFallbackSpinner.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageFallbackSpinner.tsx @@ -1,4 +1,5 @@ import { Flex, Spinner, SpinnerProps } from '@chakra-ui/react'; +import { memo } from 'react'; type ImageFallbackSpinnerProps = SpinnerProps; @@ -23,4 +24,4 @@ const ImageFallbackSpinner = (props: ImageFallbackSpinnerProps) => { ); }; -export default ImageFallbackSpinner; +export default memo(ImageFallbackSpinner); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx index 804df49b8e..6c34029490 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx @@ -18,7 +18,6 @@ import { FaImages, FaServer } from 'react-icons/fa'; import { galleryViewChanged } from '../store/gallerySlice'; import BoardsList from './Boards/BoardsList/BoardsList'; import GalleryBoardName from './GalleryBoardName'; -import GalleryPinButton from './GalleryPinButton'; import GallerySettingsPopover from './GallerySettingsPopover'; import GalleryImageGrid from './ImageGrid/GalleryImageGrid'; @@ -75,7 +74,6 @@ const ImageGalleryContent = () => { onToggle={onToggleBoardList} /> - diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx index 5dbbf011e8..40af91d53a 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx @@ -88,8 +88,12 @@ const GalleryImage = (props: HoverableImageProps) => { }, []); const starIcon = useMemo(() => { - if (imageDTO?.starred) return ; - if (!imageDTO?.starred && isHovered) return ; + if (imageDTO?.starred) { + return ; + } + if (!imageDTO?.starred && isHovered) { + return ; + } }, [imageDTO?.starred, isHovered]); if (!imageDTO) { diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/ImageGridItemContainer.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/ImageGridItemContainer.tsx index a09455ef2c..f55ca1dedf 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/ImageGridItemContainer.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/ImageGridItemContainer.tsx @@ -1,5 +1,5 @@ import { Box, FlexProps, forwardRef } from '@chakra-ui/react'; -import { PropsWithChildren } from 'react'; +import { PropsWithChildren, memo } from 'react'; type ItemContainerProps = PropsWithChildren & FlexProps; const ItemContainer = forwardRef((props: ItemContainerProps, ref) => ( @@ -8,4 +8,4 @@ const ItemContainer = forwardRef((props: ItemContainerProps, ref) => ( )); -export default ItemContainer; +export default memo(ItemContainer); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/ImageGridListContainer.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/ImageGridListContainer.tsx index fbbca2b2cf..a93222b58e 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/ImageGridListContainer.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/ImageGridListContainer.tsx @@ -1,7 +1,7 @@ import { FlexProps, Grid, forwardRef } from '@chakra-ui/react'; import { RootState } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; -import { PropsWithChildren } from 'react'; +import { PropsWithChildren, memo } from 'react'; type ListContainerProps = PropsWithChildren & FlexProps; const ListContainer = forwardRef((props: ListContainerProps, ref) => { @@ -23,4 +23,4 @@ const ListContainer = forwardRef((props: ListContainerProps, ref) => { ); }); -export default ListContainer; +export default memo(ListContainer); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataJSON.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/DataViewer.tsx similarity index 53% rename from invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataJSON.tsx rename to invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/DataViewer.tsx index 69385607de..87c0957354 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataJSON.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/DataViewer.tsx @@ -1,34 +1,37 @@ import { Box, Flex, IconButton, Tooltip } from '@chakra-ui/react'; +import { isString } from 'lodash-es'; import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; -import { useCallback, useMemo } from 'react'; +import { memo, useCallback, useMemo } from 'react'; import { FaCopy, FaSave } from 'react-icons/fa'; type Props = { label: string; - jsonObject: object; + data: object | string; fileName?: string; + withDownload?: boolean; + withCopy?: boolean; }; -const ImageMetadataJSON = (props: Props) => { - const { label, jsonObject, fileName } = props; - const jsonString = useMemo( - () => JSON.stringify(jsonObject, null, 2), - [jsonObject] +const DataViewer = (props: Props) => { + const { label, data, fileName, withDownload = true, withCopy = true } = props; + const dataString = useMemo( + () => (isString(data) ? data : JSON.stringify(data, null, 2)), + [data] ); const handleCopy = useCallback(() => { - navigator.clipboard.writeText(jsonString); - }, [jsonString]); + navigator.clipboard.writeText(dataString); + }, [dataString]); const handleSave = useCallback(() => { - const blob = new Blob([jsonString]); + const blob = new Blob([dataString]); const a = document.createElement('a'); a.href = URL.createObjectURL(blob); a.download = `${fileName || label}.json`; document.body.appendChild(a); a.click(); a.remove(); - }, [jsonString, label, fileName]); + }, [dataString, label, fileName]); return ( { }, }} > -

{jsonString}
+
{dataString}
- - } - variant="ghost" - opacity={0.7} - onClick={handleSave} - /> - - - } - variant="ghost" - opacity={0.7} - onClick={handleCopy} - /> - + {withDownload && ( + + } + variant="ghost" + opacity={0.7} + onClick={handleSave} + /> + + )} + {withCopy && ( + + } + variant="ghost" + opacity={0.7} + onClick={handleCopy} + /> + + )} ); }; -export default ImageMetadataJSON; +export default memo(DataViewer); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx index c0821c2226..ee5b342d4e 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx @@ -1,5 +1,5 @@ import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; -import { useCallback } from 'react'; +import { memo, useCallback } from 'react'; import { UnsafeImageMetadata } from 'services/api/types'; import ImageMetadataItem from './ImageMetadataItem'; @@ -206,4 +206,4 @@ const ImageMetadataActions = (props: Props) => { ); }; -export default ImageMetadataActions; +export default memo(ImageMetadataActions); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx index d72561351f..c03fd26ba1 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx @@ -1,5 +1,6 @@ import { ExternalLinkIcon } from '@chakra-ui/icons'; import { Flex, IconButton, Link, Text, Tooltip } from '@chakra-ui/react'; +import { memo } from 'react'; import { useTranslation } from 'react-i18next'; import { FaCopy } from 'react-icons/fa'; import { IoArrowUndoCircleOutline } from 'react-icons/io5'; @@ -74,4 +75,4 @@ const ImageMetadataItem = ({ ); }; -export default ImageMetadataItem; +export default memo(ImageMetadataItem); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataViewer.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataViewer.tsx index d70aea8a8d..9262d081b5 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataViewer.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataViewer.tsx @@ -16,7 +16,7 @@ import { useGetImageMetadataQuery } from 'services/api/endpoints/images'; import { ImageDTO } from 'services/api/types'; import { useDebounce } from 'use-debounce'; import ImageMetadataActions from './ImageMetadataActions'; -import ImageMetadataJSON from './ImageMetadataJSON'; +import DataViewer from './DataViewer'; type ImageMetadataViewerProps = { image: ImageDTO; @@ -79,21 +79,21 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => { {metadata ? ( - + ) : ( )} {image ? ( - + ) : ( )} {graph ? ( - + ) : ( )} diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx index 5ba4e711ef..83fddef578 100644 --- a/invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx +++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx @@ -5,6 +5,7 @@ import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { map } from 'lodash-es'; import ParamLora from './ParamLora'; +import { memo } from 'react'; const selector = createSelector( stateSelector, @@ -29,4 +30,4 @@ const ParamLoraList = () => { ); }; -export default ParamLoraList; +export default memo(ParamLoraList); diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx index 2046d36ab2..bb485d44b6 100644 --- a/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx +++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx @@ -9,7 +9,7 @@ import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectI import { loraAdded } from 'features/lora/store/loraSlice'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { forEach } from 'lodash-es'; -import { useCallback, useMemo } from 'react'; +import { memo, useCallback, useMemo } from 'react'; import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; const selector = createSelector( @@ -102,4 +102,4 @@ const ParamLoRASelect = () => { ); }; -export default ParamLoRASelect; +export default memo(ParamLoRASelect); diff --git a/invokeai/frontend/web/src/features/nodes/components/AddNodeMenu.tsx b/invokeai/frontend/web/src/features/nodes/components/AddNodeMenu.tsx deleted file mode 100644 index a816762d0f..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/AddNodeMenu.tsx +++ /dev/null @@ -1,140 +0,0 @@ -import { Flex, Text } from '@chakra-ui/react'; -import { createSelector } from '@reduxjs/toolkit'; -import { useAppToaster } from 'app/components/Toaster'; -import { stateSelector } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; -import { map } from 'lodash-es'; -import { forwardRef, useCallback } from 'react'; -import 'reactflow/dist/style.css'; -import { AnyInvocationType } from 'services/events/types'; -import { useBuildNodeData } from '../hooks/useBuildNodeData'; -import { nodeAdded } from '../store/nodesSlice'; - -type NodeTemplate = { - label: string; - value: string; - description: string; - tags: string[]; -}; - -const selector = createSelector( - [stateSelector], - ({ nodes }) => { - const data: NodeTemplate[] = map(nodes.nodeTemplates, (template) => { - return { - label: template.title, - value: template.type, - description: template.description, - tags: template.tags, - }; - }); - - data.push({ - label: 'Progress Image', - value: 'current_image', - description: 'Displays the current image in the Node Editor', - tags: ['progress'], - }); - - data.push({ - label: 'Notes', - value: 'notes', - description: 'Add notes about your workflow', - tags: ['notes'], - }); - - return { data }; - }, - defaultSelectorOptions -); - -const AddNodeMenu = () => { - const dispatch = useAppDispatch(); - const { data } = useAppSelector(selector); - - const buildInvocation = useBuildNodeData(); - - const toaster = useAppToaster(); - - const addNode = useCallback( - (nodeType: AnyInvocationType) => { - const invocation = buildInvocation(nodeType); - - if (!invocation) { - toaster({ - status: 'error', - title: `Unknown Invocation type ${nodeType}`, - }); - return; - } - - dispatch(nodeAdded(invocation)); - }, - [dispatch, buildInvocation, toaster] - ); - - const handleChange = useCallback( - (v: string | null) => { - if (!v) { - return; - } - - addNode(v as AnyInvocationType); - }, - [addNode] - ); - - return ( - - - item.label.toLowerCase().includes(value.toLowerCase().trim()) || - item.value.toLowerCase().includes(value.toLowerCase().trim()) || - item.description.toLowerCase().includes(value.toLowerCase().trim()) || - item.tags.includes(value.toLowerCase().trim()) - } - onChange={handleChange} - sx={{ - width: '24rem', - }} - /> - - ); -}; - -interface ItemProps extends React.ComponentPropsWithoutRef<'div'> { - value: string; - label: string; - description: string; -} - -const SelectItem = forwardRef( - ({ label, description, ...others }: ItemProps, ref) => { - return ( -
-
- {label} - - {description} - -
-
- ); - } -); - -SelectItem.displayName = 'SelectItem'; - -export default AddNodeMenu; diff --git a/invokeai/frontend/web/src/features/nodes/components/CustomEdges.tsx b/invokeai/frontend/web/src/features/nodes/components/CustomEdges.tsx deleted file mode 100644 index f80f0451e4..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/CustomEdges.tsx +++ /dev/null @@ -1,199 +0,0 @@ -import { Badge, Flex } from '@chakra-ui/react'; -import { createSelector } from '@reduxjs/toolkit'; -import { stateSelector } from 'app/store/store'; -import { useAppSelector } from 'app/store/storeHooks'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; -import { memo, useMemo } from 'react'; -import { - BaseEdge, - EdgeLabelRenderer, - EdgeProps, - getBezierPath, -} from 'reactflow'; -import { FIELDS, colorTokenToCssVar } from '../types/constants'; -import { isInvocationNode } from '../types/types'; - -const makeEdgeSelector = ( - source: string, - sourceHandleId: string | null | undefined, - target: string, - targetHandleId: string | null | undefined, - selected?: boolean -) => - createSelector( - stateSelector, - ({ nodes }) => { - const sourceNode = nodes.nodes.find((node) => node.id === source); - const targetNode = nodes.nodes.find((node) => node.id === target); - - const isInvocationToInvocationEdge = - isInvocationNode(sourceNode) && isInvocationNode(targetNode); - - const isSelected = - sourceNode?.selected || targetNode?.selected || selected; - const sourceType = isInvocationToInvocationEdge - ? sourceNode?.data?.outputs[sourceHandleId || '']?.type - : undefined; - - const stroke = - sourceType && nodes.shouldColorEdges - ? colorTokenToCssVar(FIELDS[sourceType].color) - : colorTokenToCssVar('base.500'); - - return { - isSelected, - shouldAnimate: nodes.shouldAnimateEdges && isSelected, - stroke, - }; - }, - defaultSelectorOptions - ); - -const CollapsedEdge = memo( - ({ - sourceX, - sourceY, - targetX, - targetY, - sourcePosition, - targetPosition, - markerEnd, - data, - selected, - source, - target, - sourceHandleId, - targetHandleId, - }: EdgeProps<{ count: number }>) => { - const selector = useMemo( - () => - makeEdgeSelector( - source, - sourceHandleId, - target, - targetHandleId, - selected - ), - [selected, source, sourceHandleId, target, targetHandleId] - ); - - const { isSelected, shouldAnimate } = useAppSelector(selector); - - const [edgePath, labelX, labelY] = getBezierPath({ - sourceX, - sourceY, - sourcePosition, - targetX, - targetY, - targetPosition, - }); - - const { base500 } = useChakraThemeTokens(); - - return ( - <> - - {data?.count && data.count > 1 && ( - - - - {data.count} - - - - )} - - ); - } -); - -CollapsedEdge.displayName = 'CollapsedEdge'; - -const DefaultEdge = memo( - ({ - sourceX, - sourceY, - targetX, - targetY, - sourcePosition, - targetPosition, - markerEnd, - selected, - source, - target, - sourceHandleId, - targetHandleId, - }: EdgeProps) => { - const selector = useMemo( - () => - makeEdgeSelector( - source, - sourceHandleId, - target, - targetHandleId, - selected - ), - [source, sourceHandleId, target, targetHandleId, selected] - ); - - const { isSelected, shouldAnimate, stroke } = useAppSelector(selector); - - const [edgePath] = getBezierPath({ - sourceX, - sourceY, - sourcePosition, - targetX, - targetY, - targetPosition, - }); - - return ( - - ); - } -); - -DefaultEdge.displayName = 'DefaultEdge'; - -export const edgeTypes = { - collapsed: CollapsedEdge, - default: DefaultEdge, -}; diff --git a/invokeai/frontend/web/src/features/nodes/components/CustomNodes.tsx b/invokeai/frontend/web/src/features/nodes/components/CustomNodes.tsx deleted file mode 100644 index be845df435..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/CustomNodes.tsx +++ /dev/null @@ -1,9 +0,0 @@ -import CurrentImageNode from './nodes/CurrentImageNode'; -import InvocationNodeWrapper from './nodes/InvocationNodeWrapper'; -import NotesNode from './nodes/NotesNode'; - -export const nodeTypes = { - invocation: InvocationNodeWrapper, - current_image: CurrentImageNode, - notes: NotesNode, -}; diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/InvocationNode.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/InvocationNode.tsx deleted file mode 100644 index 6c610d7f34..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/Invocation/InvocationNode.tsx +++ /dev/null @@ -1,74 +0,0 @@ -import { Flex } from '@chakra-ui/react'; -import { useFieldNames, useWithFooter } from 'features/nodes/hooks/useNodeData'; -import { memo } from 'react'; -import InputField from '../fields/InputField'; -import OutputField from '../fields/OutputField'; -import NodeFooter from './NodeFooter'; -import NodeHeader from './NodeHeader'; -import NodeWrapper from './NodeWrapper'; - -type Props = { - nodeId: string; - isOpen: boolean; - label: string; - type: string; - selected: boolean; -}; - -const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => { - const inputFieldNames = useFieldNames(nodeId, 'input'); - const outputFieldNames = useFieldNames(nodeId, 'output'); - const withFooter = useWithFooter(nodeId); - - return ( - - - {isOpen && ( - <> - - - {outputFieldNames.map((fieldName) => ( - - ))} - {inputFieldNames.map((fieldName) => ( - - ))} - - - {withFooter && } - - )} - - ); -}; - -export default memo(InvocationNode); diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeSettings.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeSettings.tsx deleted file mode 100644 index bf12358871..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeSettings.tsx +++ /dev/null @@ -1,69 +0,0 @@ -import { Flex } from '@chakra-ui/react'; -import { useAppDispatch } from 'app/store/storeHooks'; -import IAIIconButton from 'common/components/IAIIconButton'; -import IAIPopover from 'common/components/IAIPopover'; -import IAISwitch from 'common/components/IAISwitch'; -import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice'; -import { InvocationNodeData } from 'features/nodes/types/types'; -import { ChangeEvent, memo, useCallback } from 'react'; -import { FaBars } from 'react-icons/fa'; - -interface Props { - data: InvocationNodeData; -} - -const NodeSettings = (props: Props) => { - const { data } = props; - const dispatch = useAppDispatch(); - - const handleChangeIsIntermediate = useCallback( - (e: ChangeEvent) => { - dispatch( - fieldBooleanValueChanged({ - nodeId: data.id, - fieldName: 'is_intermediate', - value: e.target.checked, - }) - ); - }, - [data.id, dispatch] - ); - - return ( - } - /> - } - > - - - - - ); -}; - -export default memo(NodeSettings); diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeWrapper.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeWrapper.tsx deleted file mode 100644 index 68ed0684ed..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeWrapper.tsx +++ /dev/null @@ -1,99 +0,0 @@ -import { - Box, - ChakraProps, - useColorModeValue, - useToken, -} from '@chakra-ui/react'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { nodeClicked } from 'features/nodes/store/nodesSlice'; -import { - MouseEvent, - PropsWithChildren, - memo, - useCallback, - useMemo, -} from 'react'; -import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from '../../types/constants'; - -const useNodeSelect = (nodeId: string) => { - const dispatch = useAppDispatch(); - - const selectNode = useCallback( - (e: MouseEvent) => { - dispatch(nodeClicked({ nodeId, ctrlOrMeta: e.ctrlKey || e.metaKey })); - }, - [dispatch, nodeId] - ); - - return selectNode; -}; - -type NodeWrapperProps = PropsWithChildren & { - nodeId: string; - selected: boolean; - width?: NonNullable['w']; -}; - -const NodeWrapper = (props: NodeWrapperProps) => { - const { width, children, nodeId, selected } = props; - - const [ - nodeSelectedOutlineLight, - nodeSelectedOutlineDark, - shadowsXl, - shadowsBase, - ] = useToken('shadows', [ - 'nodeSelectedOutline.light', - 'nodeSelectedOutline.dark', - 'shadows.xl', - 'shadows.base', - ]); - - const selectNode = useNodeSelect(nodeId); - - const shadow = useColorModeValue( - nodeSelectedOutlineLight, - nodeSelectedOutlineDark - ); - - const shift = useAppSelector((state) => state.hotkeys.shift); - const opacity = useAppSelector((state) => state.nodes.nodeOpacity); - const className = useMemo( - () => (shift ? DRAG_HANDLE_CLASSNAME : 'nopan'), - [shift] - ); - - return ( - - - {children} - - ); -}; - -export default memo(NodeWrapper); diff --git a/invokeai/frontend/web/src/features/nodes/components/NodeEditor.tsx b/invokeai/frontend/web/src/features/nodes/components/NodeEditor.tsx index 5e610cfc39..4cefdbb20b 100644 --- a/invokeai/frontend/web/src/features/nodes/components/NodeEditor.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/NodeEditor.tsx @@ -1,107 +1,95 @@ import { Flex } from '@chakra-ui/react'; import { useAppSelector } from 'app/store/storeHooks'; import { IAINoContentFallback } from 'common/components/IAIImageFallback'; -import ResizeHandle from 'features/ui/components/tabs/ResizeHandle'; -import { memo, useState } from 'react'; -import { MdDeviceHub } from 'react-icons/md'; -import { Panel, PanelGroup } from 'react-resizable-panels'; -import 'reactflow/dist/style.css'; -import NodeEditorPanelGroup from './panel/NodeEditorPanelGroup'; -import { Flow } from './Flow'; import { AnimatePresence, motion } from 'framer-motion'; +import { memo } from 'react'; +import { MdDeviceHub } from 'react-icons/md'; +import 'reactflow/dist/style.css'; +import AddNodePopover from './flow/AddNodePopover/AddNodePopover'; +import { Flow } from './flow/Flow'; +import TopLeftPanel from './flow/panels/TopLeftPanel/TopLeftPanel'; +import TopCenterPanel from './flow/panels/TopCenterPanel/TopCenterPanel'; +import TopRightPanel from './flow/panels/TopRightPanel/TopRightPanel'; +import BottomLeftPanel from './flow/panels/BottomLeftPanel/BottomLeftPanel'; +import MinimapPanel from './flow/panels/MinimapPanel/MinimapPanel'; const NodeEditor = () => { - const [isPanelCollapsed, setIsPanelCollapsed] = useState(false); const isReady = useAppSelector((state) => state.nodes.isReady); return ( - - - - - - - - - {isReady && ( - - - - )} - - - {!isReady && ( - - - - - - )} - - - - + + {isReady && ( + + + + + + + + + + )} + + + {!isReady && ( + + + + + + )} + + ); }; diff --git a/invokeai/frontend/web/src/features/nodes/components/NodeGraphOverlay.tsx b/invokeai/frontend/web/src/features/nodes/components/NodeGraphOverlay.tsx deleted file mode 100644 index 4525dc5f6b..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/NodeGraphOverlay.tsx +++ /dev/null @@ -1,26 +0,0 @@ -import { RootState } from 'app/store/store'; -import { useAppSelector } from 'app/store/storeHooks'; -import ImageMetadataJSON from 'features/gallery/components/ImageMetadataViewer/ImageMetadataJSON'; -import { omit } from 'lodash-es'; -import { useMemo } from 'react'; -import { useDebounce } from 'use-debounce'; -import { buildNodesGraph } from '../util/graphBuilders/buildNodesGraph'; - -const useNodesGraph = () => { - const nodes = useAppSelector((state: RootState) => state.nodes); - const [debouncedNodes] = useDebounce(nodes, 300); - const graph = useMemo( - () => omit(buildNodesGraph(debouncedNodes), 'id'), - [debouncedNodes] - ); - - return graph; -}; - -const NodeGraph = () => { - const graph = useNodesGraph(); - - return ; -}; - -export default NodeGraph; diff --git a/invokeai/frontend/web/src/features/nodes/components/editorPanels/BottomLeftPanel.tsx b/invokeai/frontend/web/src/features/nodes/components/editorPanels/BottomLeftPanel.tsx deleted file mode 100644 index 39aa2444c4..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/editorPanels/BottomLeftPanel.tsx +++ /dev/null @@ -1,16 +0,0 @@ -import { memo } from 'react'; -import { Panel } from 'reactflow'; -import ViewportControls from '../ViewportControls'; -import NodeOpacitySlider from '../NodeOpacitySlider'; -import { Flex } from '@chakra-ui/react'; - -const BottomLeftPanel = () => ( - - - - - - -); - -export default memo(BottomLeftPanel); diff --git a/invokeai/frontend/web/src/features/nodes/components/editorPanels/TopCenterPanel.tsx b/invokeai/frontend/web/src/features/nodes/components/editorPanels/TopCenterPanel.tsx deleted file mode 100644 index 240c2057be..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/editorPanels/TopCenterPanel.tsx +++ /dev/null @@ -1,24 +0,0 @@ -import { HStack } from '@chakra-ui/react'; -import CancelButton from 'features/parameters/components/ProcessButtons/CancelButton'; -import { memo } from 'react'; -import { Panel } from 'reactflow'; -import NodeEditorSettings from '../NodeEditorSettings'; -import ClearGraphButton from '../ui/ClearGraphButton'; -import NodeInvokeButton from '../ui/NodeInvokeButton'; -import ReloadSchemaButton from '../ui/ReloadSchemaButton'; - -const TopCenterPanel = () => { - return ( - - - - - - - - - - ); -}; - -export default memo(TopCenterPanel); diff --git a/invokeai/frontend/web/src/features/nodes/components/editorPanels/TopLeftPanel.tsx b/invokeai/frontend/web/src/features/nodes/components/editorPanels/TopLeftPanel.tsx deleted file mode 100644 index 2b89db000a..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/editorPanels/TopLeftPanel.tsx +++ /dev/null @@ -1,11 +0,0 @@ -import { memo } from 'react'; -import { Panel } from 'reactflow'; -import AddNodeMenu from '../AddNodeMenu'; - -const TopLeftPanel = () => ( - - - -); - -export default memo(TopLeftPanel); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/FieldContextMenu.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/FieldContextMenu.tsx deleted file mode 100644 index d9f8f951bc..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/fields/FieldContextMenu.tsx +++ /dev/null @@ -1,47 +0,0 @@ -import { MenuItem, MenuList } from '@chakra-ui/react'; -import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu'; -import { - InputFieldTemplate, - InputFieldValue, -} from 'features/nodes/types/types'; -import { MouseEvent, useCallback } from 'react'; -import { menuListMotionProps } from 'theme/components/menu'; - -type Props = { - nodeId: string; - field: InputFieldValue; - fieldTemplate: InputFieldTemplate; - children: ContextMenuProps['children']; -}; - -const FieldContextMenu = (props: Props) => { - const skipEvent = useCallback((e: MouseEvent) => { - e.preventDefault(); - }, []); - - return ( - - menuProps={{ - size: 'sm', - isLazy: true, - }} - menuButtonProps={{ - bg: 'transparent', - _hover: { bg: 'transparent' }, - }} - renderMenu={() => ( - - Test - - )} - > - {props.children} - - ); -}; - -export default FieldContextMenu; diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/InputField.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/InputField.tsx deleted file mode 100644 index 47033baa7b..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/fields/InputField.tsx +++ /dev/null @@ -1,139 +0,0 @@ -import { Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react'; -import { useConnectionState } from 'features/nodes/hooks/useConnectionState'; -import { - useDoesInputHaveValue, - useFieldTemplate, -} from 'features/nodes/hooks/useNodeData'; -import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants'; -import { PropsWithChildren, memo, useMemo } from 'react'; -import FieldHandle from './FieldHandle'; -import FieldTitle from './FieldTitle'; -import FieldTooltipContent from './FieldTooltipContent'; -import InputFieldRenderer from './InputFieldRenderer'; - -interface Props { - nodeId: string; - fieldName: string; -} - -const InputField = ({ nodeId, fieldName }: Props) => { - const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input'); - const doesFieldHaveValue = useDoesInputHaveValue(nodeId, fieldName); - - const { - isConnected, - isConnectionInProgress, - isConnectionStartField, - connectionError, - shouldDim, - } = useConnectionState({ nodeId, fieldName, kind: 'input' }); - - const isMissingInput = useMemo(() => { - if (fieldTemplate?.fieldKind !== 'input') { - return false; - } - - if (!fieldTemplate.required) { - return false; - } - - if (!isConnected && fieldTemplate.input === 'connection') { - return true; - } - - if (!doesFieldHaveValue && !isConnected && fieldTemplate.input === 'any') { - return true; - } - }, [fieldTemplate, isConnected, doesFieldHaveValue]); - - if (fieldTemplate?.fieldKind !== 'input') { - return ( - - - Unknown input: {fieldName} - - - ); - } - - return ( - - - - } - openDelay={HANDLE_TOOLTIP_OPEN_DELAY} - placement="top" - shouldWrapChildren - hasArrow - > - - - - - - - - {fieldTemplate.input !== 'direct' && ( - - )} - - ); -}; - -export default InputField; - -type InputFieldWrapperProps = PropsWithChildren<{ - shouldDim: boolean; -}>; - -const InputFieldWrapper = memo( - ({ shouldDim, children }: InputFieldWrapperProps) => ( - - {children} - - ) -); - -InputFieldWrapper.displayName = 'InputFieldWrapper'; diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/LinearViewField.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/LinearViewField.tsx deleted file mode 100644 index ea4bb76d62..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/fields/LinearViewField.tsx +++ /dev/null @@ -1,54 +0,0 @@ -import { Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react'; -import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants'; -import { memo } from 'react'; -import FieldTitle from './FieldTitle'; -import FieldTooltipContent from './FieldTooltipContent'; -import InputFieldRenderer from './InputFieldRenderer'; - -type Props = { - nodeId: string; - fieldName: string; -}; - -const LinearViewField = ({ nodeId, fieldName }: Props) => { - return ( - - - - } - openDelay={HANDLE_TOOLTIP_OPEN_DELAY} - placement="top" - shouldWrapChildren - hasArrow - > - - - - - - - - ); -}; - -export default memo(LinearViewField); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/types.ts b/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/types.ts deleted file mode 100644 index 5a5e3a9dcf..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/types.ts +++ /dev/null @@ -1,13 +0,0 @@ -import { - InputFieldTemplate, - InputFieldValue, -} from 'features/nodes/types/types'; - -export type FieldComponentProps< - V extends InputFieldValue, - T extends InputFieldTemplate -> = { - nodeId: string; - field: V; - fieldTemplate: T; -}; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx new file mode 100644 index 0000000000..83f7482177 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -0,0 +1,205 @@ +import { + Flex, + Popover, + PopoverAnchor, + PopoverBody, + PopoverContent, +} from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { useAppToaster } from 'app/components/Toaster'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; +import { useBuildNodeData } from 'features/nodes/hooks/useBuildNodeData'; +import { + addNodePopoverClosed, + addNodePopoverOpened, + nodeAdded, +} from 'features/nodes/store/nodesSlice'; +import { map } from 'lodash-es'; +import { memo, useCallback, useRef } from 'react'; +import { useHotkeys } from 'react-hotkeys-hook'; +import { HotkeyCallback } from 'react-hotkeys-hook/dist/types'; +import 'reactflow/dist/style.css'; +import { AnyInvocationType } from 'services/events/types'; +import { AddNodePopoverSelectItem } from './AddNodePopoverSelectItem'; + +type NodeTemplate = { + label: string; + value: string; + description: string; + tags: string[]; +}; + +const filter = (value: string, item: NodeTemplate) => { + const regex = new RegExp( + value + .trim() + .replace(/[-[\]{}()*+!<=:?./\\^$|#,]/g, '') + .split(' ') + .join('.*'), + 'gi' + ); + return ( + regex.test(item.label) || + regex.test(item.description) || + item.tags.some((tag) => regex.test(tag)) + ); +}; + +const selector = createSelector( + [stateSelector], + ({ nodes }) => { + const data: NodeTemplate[] = map(nodes.nodeTemplates, (template) => { + return { + label: template.title, + value: template.type, + description: template.description, + tags: template.tags, + }; + }); + + data.push({ + label: 'Progress Image', + value: 'current_image', + description: 'Displays the current image in the Node Editor', + tags: ['progress'], + }); + + data.push({ + label: 'Notes', + value: 'notes', + description: 'Add notes about your workflow', + tags: ['notes'], + }); + + data.sort((a, b) => a.label.localeCompare(b.label)); + + return { data }; + }, + defaultSelectorOptions +); + +const AddNodePopover = () => { + const dispatch = useAppDispatch(); + const buildInvocation = useBuildNodeData(); + const toaster = useAppToaster(); + const { data } = useAppSelector(selector); + const isOpen = useAppSelector((state) => state.nodes.isAddNodePopoverOpen); + const inputRef = useRef(null); + + const addNode = useCallback( + (nodeType: AnyInvocationType) => { + const invocation = buildInvocation(nodeType); + + if (!invocation) { + toaster({ + status: 'error', + title: `Unknown Invocation type ${nodeType}`, + }); + return; + } + + dispatch(nodeAdded(invocation)); + }, + [dispatch, buildInvocation, toaster] + ); + + const handleChange = useCallback( + (v: string | null) => { + if (!v) { + return; + } + + addNode(v as AnyInvocationType); + }, + [addNode] + ); + + const onClose = useCallback(() => { + dispatch(addNodePopoverClosed()); + }, [dispatch]); + + const onOpen = useCallback(() => { + dispatch(addNodePopoverOpened()); + }, [dispatch]); + + const handleHotkeyOpen: HotkeyCallback = useCallback( + (e) => { + e.preventDefault(); + onOpen(); + setTimeout(() => { + inputRef.current?.focus(); + }, 0); + }, + [onOpen] + ); + + const handleHotkeyClose: HotkeyCallback = useCallback(() => { + onClose(); + }, [onClose]); + + useHotkeys(['shift+a', 'space'], handleHotkeyOpen); + useHotkeys(['escape'], handleHotkeyClose); + + return ( + + + + + + + + + + + ); +}; + +export default memo(AddNodePopover); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopoverSelectItem.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopoverSelectItem.tsx new file mode 100644 index 0000000000..95b033f95c --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopoverSelectItem.tsx @@ -0,0 +1,29 @@ +import { Text } from '@chakra-ui/react'; +import { forwardRef } from 'react'; +import 'reactflow/dist/style.css'; + +interface ItemProps extends React.ComponentPropsWithoutRef<'div'> { + value: string; + label: string; + description: string; +} + +export const AddNodePopoverSelectItem = forwardRef( + ({ label, description, ...others }: ItemProps, ref) => { + return ( +
+
+ {label} + + {description} + +
+
+ ); + } +); + +AddNodePopoverSelectItem.displayName = 'AddNodePopoverSelectItem'; diff --git a/invokeai/frontend/web/src/features/nodes/components/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx similarity index 66% rename from invokeai/frontend/web/src/features/nodes/components/Flow.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx index 3290a65054..e8fb66d074 100644 --- a/invokeai/frontend/web/src/features/nodes/components/Flow.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx @@ -1,7 +1,11 @@ import { useToken } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { contextMenusClosed } from 'features/ui/store/uiSlice'; import { useCallback } from 'react'; +import { useHotkeys } from 'react-hotkeys-hook'; import { Background, OnConnect, @@ -16,7 +20,7 @@ import { ProOptions, ReactFlow, } from 'reactflow'; -import { useIsValidConnection } from '../hooks/useIsValidConnection'; +import { useIsValidConnection } from '../../hooks/useIsValidConnection'; import { connectionEnded, connectionMade, @@ -25,30 +29,54 @@ import { edgesDeleted, nodesChanged, nodesDeleted, + selectedAll, selectedEdgesChanged, selectedNodesChanged, + selectionCopied, + selectionPasted, viewportChanged, -} from '../store/nodesSlice'; -import { CustomConnectionLine } from './CustomConnectionLine'; -import { edgeTypes } from './CustomEdges'; -import { nodeTypes } from './CustomNodes'; -import BottomLeftPanel from './editorPanels/BottomLeftPanel'; -import MinimapPanel from './editorPanels/MinimapPanel'; -import TopCenterPanel from './editorPanels/TopCenterPanel'; -import TopLeftPanel from './editorPanels/TopLeftPanel'; -import TopRightPanel from './editorPanels/TopRightPanel'; +} from '../../store/nodesSlice'; +import CustomConnectionLine from './connectionLines/CustomConnectionLine'; +import InvocationCollapsedEdge from './edges/InvocationCollapsedEdge'; +import InvocationDefaultEdge from './edges/InvocationDefaultEdge'; +import CurrentImageNode from './nodes/CurrentImage/CurrentImageNode'; +import InvocationNodeWrapper from './nodes/Invocation/InvocationNodeWrapper'; +import NotesNode from './nodes/Notes/NotesNode'; + +const DELETE_KEYS = ['Delete', 'Backspace']; + +const edgeTypes = { + collapsed: InvocationCollapsedEdge, + default: InvocationDefaultEdge, +}; + +const nodeTypes = { + invocation: InvocationNodeWrapper, + current_image: CurrentImageNode, + notes: NotesNode, +}; // TODO: can we support reactflow? if not, we could style the attribution so it matches the app const proOptions: ProOptions = { hideAttribution: true }; +const selector = createSelector( + stateSelector, + ({ nodes }) => { + const { shouldSnapToGrid, selectionMode } = nodes; + return { + shouldSnapToGrid, + selectionMode, + }; + }, + defaultSelectorOptions +); + export const Flow = () => { const dispatch = useAppDispatch(); const nodes = useAppSelector((state) => state.nodes.nodes); const edges = useAppSelector((state) => state.nodes.edges); const viewport = useAppSelector((state) => state.nodes.viewport); - const shouldSnapToGrid = useAppSelector( - (state) => state.nodes.shouldSnapToGrid - ); + const { shouldSnapToGrid, selectionMode } = useAppSelector(selector); const isValidConnection = useIsValidConnection(); @@ -119,8 +147,24 @@ export const Flow = () => { dispatch(contextMenusClosed()); }, [dispatch]); + useHotkeys(['Ctrl+c', 'Meta+c'], (e) => { + e.preventDefault(); + dispatch(selectionCopied()); + }); + + useHotkeys(['Ctrl+a', 'Meta+a'], (e) => { + e.preventDefault(); + dispatch(selectedAll()); + }); + + useHotkeys(['Ctrl+v', 'Meta+v'], (e) => { + e.preventDefault(); + dispatch(selectionPasted()); + }); + return ( { proOptions={proOptions} style={{ borderRadius }} onPaneClick={handlePaneClick} + deleteKeyCode={DELETE_KEYS} + selectionMode={selectionMode} > - - - - - ); diff --git a/invokeai/frontend/web/src/features/nodes/components/CustomConnectionLine.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/connectionLines/CustomConnectionLine.tsx similarity index 85% rename from invokeai/frontend/web/src/features/nodes/components/CustomConnectionLine.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/connectionLines/CustomConnectionLine.tsx index 678d8e3d1d..a379be7ee2 100644 --- a/invokeai/frontend/web/src/features/nodes/components/CustomConnectionLine.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/connectionLines/CustomConnectionLine.tsx @@ -1,8 +1,10 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; +import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; +import { FIELDS } from 'features/nodes/types/constants'; +import { memo } from 'react'; import { ConnectionLineComponentProps, getBezierPath } from 'reactflow'; -import { FIELDS, colorTokenToCssVar } from '../types/constants'; const selector = createSelector(stateSelector, ({ nodes }) => { const { shouldAnimateEdges, currentConnectionFieldType, shouldColorEdges } = @@ -25,7 +27,7 @@ const selector = createSelector(stateSelector, ({ nodes }) => { }; }); -export const CustomConnectionLine = ({ +const CustomConnectionLine = ({ fromX, fromY, fromPosition, @@ -59,3 +61,5 @@ export const CustomConnectionLine = ({ ); }; + +export default memo(CustomConnectionLine); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationCollapsedEdge.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationCollapsedEdge.tsx new file mode 100644 index 0000000000..fca38def34 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationCollapsedEdge.tsx @@ -0,0 +1,94 @@ +import { Badge, Flex } from '@chakra-ui/react'; +import { useAppSelector } from 'app/store/storeHooks'; +import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; +import { memo, useMemo } from 'react'; +import { + BaseEdge, + EdgeLabelRenderer, + EdgeProps, + getBezierPath, +} from 'reactflow'; +import { makeEdgeSelector } from './util/makeEdgeSelector'; + +const InvocationCollapsedEdge = ({ + sourceX, + sourceY, + targetX, + targetY, + sourcePosition, + targetPosition, + markerEnd, + data, + selected, + source, + target, + sourceHandleId, + targetHandleId, +}: EdgeProps<{ count: number }>) => { + const selector = useMemo( + () => + makeEdgeSelector( + source, + sourceHandleId, + target, + targetHandleId, + selected + ), + [selected, source, sourceHandleId, target, targetHandleId] + ); + + const { isSelected, shouldAnimate } = useAppSelector(selector); + + const [edgePath, labelX, labelY] = getBezierPath({ + sourceX, + sourceY, + sourcePosition, + targetX, + targetY, + targetPosition, + }); + + const { base500 } = useChakraThemeTokens(); + + return ( + <> + + {data?.count && data.count > 1 && ( + + + + {data.count} + + + + )} + + ); +}; + +export default memo(InvocationCollapsedEdge); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationDefaultEdge.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationDefaultEdge.tsx new file mode 100644 index 0000000000..effefb12ab --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationDefaultEdge.tsx @@ -0,0 +1,58 @@ +import { useAppSelector } from 'app/store/storeHooks'; +import { memo, useMemo } from 'react'; +import { BaseEdge, EdgeProps, getBezierPath } from 'reactflow'; +import { makeEdgeSelector } from './util/makeEdgeSelector'; + +const InvocationDefaultEdge = ({ + sourceX, + sourceY, + targetX, + targetY, + sourcePosition, + targetPosition, + markerEnd, + selected, + source, + target, + sourceHandleId, + targetHandleId, +}: EdgeProps) => { + const selector = useMemo( + () => + makeEdgeSelector( + source, + sourceHandleId, + target, + targetHandleId, + selected + ), + [source, sourceHandleId, target, targetHandleId, selected] + ); + + const { isSelected, shouldAnimate, stroke } = useAppSelector(selector); + + const [edgePath] = getBezierPath({ + sourceX, + sourceY, + sourcePosition, + targetX, + targetY, + targetPosition, + }); + + return ( + + ); +}; + +export default memo(InvocationDefaultEdge); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts new file mode 100644 index 0000000000..b5dc484eae --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts @@ -0,0 +1,42 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; +import { FIELDS } from 'features/nodes/types/constants'; +import { isInvocationNode } from 'features/nodes/types/types'; + +export const makeEdgeSelector = ( + source: string, + sourceHandleId: string | null | undefined, + target: string, + targetHandleId: string | null | undefined, + selected?: boolean +) => + createSelector( + stateSelector, + ({ nodes }) => { + const sourceNode = nodes.nodes.find((node) => node.id === source); + const targetNode = nodes.nodes.find((node) => node.id === target); + + const isInvocationToInvocationEdge = + isInvocationNode(sourceNode) && isInvocationNode(targetNode); + + const isSelected = + sourceNode?.selected || targetNode?.selected || selected; + const sourceType = isInvocationToInvocationEdge + ? sourceNode?.data?.outputs[sourceHandleId || '']?.type + : undefined; + + const stroke = + sourceType && nodes.shouldColorEdges + ? colorTokenToCssVar(FIELDS[sourceType].color) + : colorTokenToCssVar('base.500'); + + return { + isSelected, + shouldAnimate: nodes.shouldAnimateEdges && isSelected, + stroke, + }; + }, + defaultSelectorOptions + ); diff --git a/invokeai/frontend/web/src/features/nodes/components/nodes/CurrentImageNode.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/CurrentImage/CurrentImageNode.tsx similarity index 97% rename from invokeai/frontend/web/src/features/nodes/components/nodes/CurrentImageNode.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/CurrentImage/CurrentImageNode.tsx index 985978f72d..6a8a2a3552 100644 --- a/invokeai/frontend/web/src/features/nodes/components/nodes/CurrentImageNode.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/CurrentImage/CurrentImageNode.tsx @@ -7,7 +7,7 @@ import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; import { PropsWithChildren, memo } from 'react'; import { useSelector } from 'react-redux'; import { NodeProps } from 'reactflow'; -import NodeWrapper from '../Invocation/NodeWrapper'; +import NodeWrapper from '../common/NodeWrapper'; const selector = createSelector(stateSelector, ({ system, gallery }) => { const imageDTO = gallery.selection[gallery.selection.length - 1]; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNode.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNode.tsx new file mode 100644 index 0000000000..8f6a2531a0 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNode.tsx @@ -0,0 +1,86 @@ +import { Flex, Grid, GridItem } from '@chakra-ui/react'; +import { memo } from 'react'; +import InvocationNodeFooter from './InvocationNodeFooter'; +import InvocationNodeHeader from './InvocationNodeHeader'; +import NodeWrapper from '../common/NodeWrapper'; +import OutputField from './fields/OutputField'; +import InputField from './fields/InputField'; +import { useOutputFieldNames } from 'features/nodes/hooks/useOutputFieldNames'; +import { useWithFooter } from 'features/nodes/hooks/useWithFooter'; +import { useConnectionInputFieldNames } from 'features/nodes/hooks/useConnectionInputFieldNames'; +import { useAnyOrDirectInputFieldNames } from 'features/nodes/hooks/useAnyOrDirectInputFieldNames'; + +type Props = { + nodeId: string; + isOpen: boolean; + label: string; + type: string; + selected: boolean; +}; + +const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => { + const inputConnectionFieldNames = useConnectionInputFieldNames(nodeId); + const inputAnyOrDirectFieldNames = useAnyOrDirectInputFieldNames(nodeId); + const outputFieldNames = useOutputFieldNames(nodeId); + const withFooter = useWithFooter(nodeId); + + return ( + + + {isOpen && ( + <> + + + + {inputConnectionFieldNames.map((fieldName, i) => ( + + + + ))} + {outputFieldNames.map((fieldName, i) => ( + + + + ))} + + {inputAnyOrDirectFieldNames.map((fieldName) => ( + + ))} + + + {withFooter && } + + )} + + ); +}; + +export default memo(InvocationNode); diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeCollapsedHandles.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx similarity index 94% rename from invokeai/frontend/web/src/features/nodes/components/Invocation/NodeCollapsedHandles.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx index 32dd554ef4..30e02bfd84 100644 --- a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeCollapsedHandles.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx @@ -10,7 +10,7 @@ interface Props { nodeId: string; } -const NodeCollapsedHandles = ({ nodeId }: Props) => { +const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => { const data = useNodeData(nodeId); const { base400, base600 } = useChakraThemeTokens(); const backgroundColor = useColorModeValue(base400, base600); @@ -71,4 +71,4 @@ const NodeCollapsedHandles = ({ nodeId }: Props) => { ); }; -export default memo(NodeCollapsedHandles); +export default memo(InvocationNodeCollapsedHandles); diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeFooter.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeFooter.tsx similarity index 86% rename from invokeai/frontend/web/src/features/nodes/components/Invocation/NodeFooter.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeFooter.tsx index c858872b57..ffcdd13fef 100644 --- a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeFooter.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeFooter.tsx @@ -6,10 +6,8 @@ import { Spacer, } from '@chakra-ui/react'; import { useAppDispatch } from 'app/store/storeHooks'; -import { - useHasImageOutput, - useIsIntermediate, -} from 'features/nodes/hooks/useNodeData'; +import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput'; +import { useIsIntermediate } from 'features/nodes/hooks/useIsIntermediate'; import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice'; import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; import { ChangeEvent, memo, useCallback } from 'react'; @@ -18,7 +16,7 @@ type Props = { nodeId: string; }; -const NodeFooter = ({ nodeId }: Props) => { +const InvocationNodeFooter = ({ nodeId }: Props) => { return ( { ); }; -export default memo(NodeFooter); +export default memo(InvocationNodeFooter); const SaveImageCheckbox = memo(({ nodeId }: { nodeId: string }) => { const dispatch = useAppDispatch(); diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeHeader.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeHeader.tsx similarity index 52% rename from invokeai/frontend/web/src/features/nodes/components/Invocation/NodeHeader.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeHeader.tsx index ea503a8f27..cd6c5215d1 100644 --- a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeHeader.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeHeader.tsx @@ -1,10 +1,10 @@ import { Flex } from '@chakra-ui/react'; import { memo } from 'react'; -import NodeCollapseButton from '../Invocation/NodeCollapseButton'; -import NodeCollapsedHandles from '../Invocation/NodeCollapsedHandles'; -import NodeNotesEdit from '../Invocation/NodeNotesEdit'; -import NodeStatusIndicator from '../Invocation/NodeStatusIndicator'; -import NodeTitle from '../Invocation/NodeTitle'; +import NodeCollapseButton from '../common/NodeCollapseButton'; +import NodeTitle from '../common/NodeTitle'; +import InvocationNodeCollapsedHandles from './InvocationNodeCollapsedHandles'; +import InvocationNodeNotes from './InvocationNodeNotes'; +import InvocationNodeStatusIndicator from './InvocationNodeStatusIndicator'; type Props = { nodeId: string; @@ -14,7 +14,7 @@ type Props = { selected: boolean; }; -const NodeHeader = ({ nodeId, isOpen }: Props) => { +const InvocationNodeHeader = ({ nodeId, isOpen }: Props) => { return ( { justifyContent: 'space-between', h: 8, textAlign: 'center', - fontWeight: 600, + fontWeight: 500, color: 'base.700', _dark: { color: 'base.200' }, }} @@ -33,12 +33,12 @@ const NodeHeader = ({ nodeId, isOpen }: Props) => { - - + + - {!isOpen && } + {!isOpen && } ); }; -export default memo(NodeHeader); +export default memo(InvocationNodeHeader); diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeNotesEdit.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeNotes.tsx similarity index 87% rename from invokeai/frontend/web/src/features/nodes/components/Invocation/NodeNotesEdit.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeNotes.tsx index fa5a9d76fb..aca5f75224 100644 --- a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeNotesEdit.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeNotes.tsx @@ -16,14 +16,11 @@ import { } from '@chakra-ui/react'; import { useAppDispatch } from 'app/store/storeHooks'; import IAITextarea from 'common/components/IAITextarea'; -import { - useNodeData, - useNodeLabel, - useNodeTemplate, - useNodeTemplateTitle, -} from 'features/nodes/hooks/useNodeData'; +import { useNodeData } from 'features/nodes/hooks/useNodeData'; +import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel'; +import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate'; +import { useNodeTemplateTitle } from 'features/nodes/hooks/useNodeTemplateTitle'; import { nodeNotesChanged } from 'features/nodes/store/nodesSlice'; -import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; import { isInvocationNodeData } from 'features/nodes/types/types'; import { ChangeEvent, memo, useCallback } from 'react'; import { FaInfoCircle } from 'react-icons/fa'; @@ -32,7 +29,7 @@ interface Props { nodeId: string; } -const NodeNotesEdit = ({ nodeId }: Props) => { +const InvocationNodeNotes = ({ nodeId }: Props) => { const { isOpen, onOpen, onClose } = useDisclosure(); const label = useNodeLabel(nodeId); const title = useNodeTemplateTitle(nodeId); @@ -45,7 +42,7 @@ const NodeNotesEdit = ({ nodeId }: Props) => { shouldWrapChildren > { ); }; -export default memo(NodeNotesEdit); +export default memo(InvocationNodeNotes); const TooltipContent = memo(({ nodeId }: { nodeId: string }) => { const data = useNodeData(nodeId); diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeStatusIndicator.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeStatusIndicator.tsx similarity index 97% rename from invokeai/frontend/web/src/features/nodes/components/Invocation/NodeStatusIndicator.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeStatusIndicator.tsx index d53fec4b42..6e1da90ad8 100644 --- a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeStatusIndicator.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeStatusIndicator.tsx @@ -28,7 +28,7 @@ const circleStyles = { '.chakra-progress__track': { stroke: 'transparent' }, }; -const NodeStatusIndicator = ({ nodeId }: Props) => { +const InvocationNodeStatusIndicator = ({ nodeId }: Props) => { const selectNodeExecutionState = useMemo( () => createSelector( @@ -64,7 +64,7 @@ const NodeStatusIndicator = ({ nodeId }: Props) => { ); }; -export default memo(NodeStatusIndicator); +export default memo(InvocationNodeStatusIndicator); type TooltipLabelProps = { nodeExecutionState: NodeExecutionState; diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/UnknownNodeFallback.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeUnknownFallback.tsx similarity index 88% rename from invokeai/frontend/web/src/features/nodes/components/Invocation/UnknownNodeFallback.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeUnknownFallback.tsx index 664a788b5a..7ec59f00f0 100644 --- a/invokeai/frontend/web/src/features/nodes/components/Invocation/UnknownNodeFallback.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeUnknownFallback.tsx @@ -1,8 +1,8 @@ import { Box, Flex, Text } from '@chakra-ui/react'; import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; import { memo } from 'react'; -import NodeCollapseButton from '../Invocation/NodeCollapseButton'; -import NodeWrapper from '../Invocation/NodeWrapper'; +import NodeCollapseButton from '../common/NodeCollapseButton'; +import NodeWrapper from '../common/NodeWrapper'; type Props = { nodeId: string; @@ -12,7 +12,7 @@ type Props = { selected: boolean; }; -const UnknownNodeFallback = ({ +const InvocationNodeUnknownFallback = ({ nodeId, isOpen, label, @@ -72,4 +72,4 @@ const UnknownNodeFallback = ({ ); }; -export default memo(UnknownNodeFallback); +export default memo(InvocationNodeUnknownFallback); diff --git a/invokeai/frontend/web/src/features/nodes/components/nodes/InvocationNodeWrapper.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeWrapper.tsx similarity index 90% rename from invokeai/frontend/web/src/features/nodes/components/nodes/InvocationNodeWrapper.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeWrapper.tsx index 26bda27d8b..3c79eac1d3 100644 --- a/invokeai/frontend/web/src/features/nodes/components/nodes/InvocationNodeWrapper.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeWrapper.tsx @@ -5,7 +5,7 @@ import { InvocationNodeData } from 'features/nodes/types/types'; import { memo, useMemo } from 'react'; import { NodeProps } from 'reactflow'; import InvocationNode from '../Invocation/InvocationNode'; -import UnknownNodeFallback from '../Invocation/UnknownNodeFallback'; +import InvocationNodeUnknownFallback from './InvocationNodeUnknownFallback'; const InvocationNodeWrapper = (props: NodeProps) => { const { data, selected } = props; @@ -23,7 +23,7 @@ const InvocationNodeWrapper = (props: NodeProps) => { if (!nodeTemplate) { return ( - ['children']; +}; + +const FieldContextMenu = ({ nodeId, fieldName, kind, children }: Props) => { + const dispatch = useAppDispatch(); + const label = useFieldLabel(nodeId, fieldName); + const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, kind); + const input = useFieldInputKind(nodeId, fieldName); + + const skipEvent = useCallback((e: MouseEvent) => { + e.preventDefault(); + }, []); + + const selector = useMemo( + () => + createSelector( + stateSelector, + ({ nodes }) => { + const isExposed = Boolean( + nodes.workflow.exposedFields.find( + (f) => f.nodeId === nodeId && f.fieldName === fieldName + ) + ); + + return { isExposed }; + }, + defaultSelectorOptions + ), + [fieldName, nodeId] + ); + + const mayExpose = useMemo( + () => ['any', 'direct'].includes(input ?? '__UNKNOWN_INPUT__'), + [input] + ); + + const { isExposed } = useAppSelector(selector); + + const handleExposeField = useCallback(() => { + dispatch(workflowExposedFieldAdded({ nodeId, fieldName })); + }, [dispatch, fieldName, nodeId]); + + const handleUnexposeField = useCallback(() => { + dispatch(workflowExposedFieldRemoved({ nodeId, fieldName })); + }, [dispatch, fieldName, nodeId]); + + const menuItems = useMemo(() => { + const menuItems: ReactNode[] = []; + if (mayExpose && !isExposed) { + menuItems.push( + } + onClick={handleExposeField} + > + Add to Linear View + + ); + } + if (mayExpose && isExposed) { + menuItems.push( + } + onClick={handleUnexposeField} + > + Remove from Linear View + + ); + } + return menuItems; + }, [ + fieldName, + handleExposeField, + handleUnexposeField, + isExposed, + mayExpose, + nodeId, + ]); + + return ( + + menuProps={{ + size: 'sm', + isLazy: true, + }} + menuButtonProps={{ + bg: 'transparent', + _hover: { bg: 'transparent' }, + }} + renderMenu={() => + !menuItems.length ? null : ( + + + {menuItems} + + + ) + } + > + {children} + + ); +}; + +export default memo(FieldContextMenu); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/FieldHandle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx similarity index 92% rename from invokeai/frontend/web/src/features/nodes/components/fields/FieldHandle.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx index f79a57a4eb..14924a16fe 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/FieldHandle.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx @@ -1,12 +1,15 @@ import { Tooltip } from '@chakra-ui/react'; -import { CSSProperties, memo, useMemo } from 'react'; -import { Handle, HandleType, Position } from 'reactflow'; +import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; import { FIELDS, HANDLE_TOOLTIP_OPEN_DELAY, - colorTokenToCssVar, -} from '../../types/constants'; -import { InputFieldTemplate, OutputFieldTemplate } from '../../types/types'; +} from 'features/nodes/types/constants'; +import { + InputFieldTemplate, + OutputFieldTemplate, +} from 'features/nodes/types/types'; +import { CSSProperties, memo, useMemo } from 'react'; +import { Handle, HandleType, Position } from 'reactflow'; export const handleBaseStyles: CSSProperties = { position: 'absolute', diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/FieldTitle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldTitle.tsx similarity index 52% rename from invokeai/frontend/web/src/features/nodes/components/fields/FieldTitle.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldTitle.tsx index e9a49989f6..7a0ee62a88 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/FieldTitle.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldTitle.tsx @@ -3,63 +3,41 @@ import { EditableInput, EditablePreview, Flex, + forwardRef, useEditableControls, } from '@chakra-ui/react'; import { useAppDispatch } from 'app/store/storeHooks'; -import IAIDraggable from 'common/components/IAIDraggable'; -import { NodeFieldDraggableData } from 'features/dnd/types'; -import { - useFieldData, - useFieldTemplate, -} from 'features/nodes/hooks/useNodeData'; +import { useFieldLabel } from 'features/nodes/hooks/useFieldLabel'; +import { useFieldTemplateTitle } from 'features/nodes/hooks/useFieldTemplateTitle'; import { fieldLabelChanged } from 'features/nodes/store/nodesSlice'; -import { - MouseEvent, - memo, - useCallback, - useEffect, - useMemo, - useState, -} from 'react'; +import { MouseEvent, memo, useCallback, useEffect, useState } from 'react'; interface Props { nodeId: string; fieldName: string; - isDraggable?: boolean; kind: 'input' | 'output'; + isMissingInput?: boolean; } -const FieldTitle = (props: Props) => { - const { nodeId, fieldName, isDraggable = false, kind } = props; - const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind); - const field = useFieldData(nodeId, fieldName); +const FieldTitle = forwardRef((props: Props, ref) => { + const { nodeId, fieldName, kind, isMissingInput = false } = props; + const label = useFieldLabel(nodeId, fieldName); + const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, kind); const dispatch = useAppDispatch(); const [localTitle, setLocalTitle] = useState( - field?.label || fieldTemplate?.title || 'Unknown Field' - ); - - const draggableData: NodeFieldDraggableData | undefined = useMemo( - () => - field && - fieldTemplate?.fieldKind === 'input' && - fieldTemplate?.input !== 'connection' && - isDraggable - ? { - id: `${nodeId}-${fieldName}`, - payloadType: 'NODE_FIELD', - payload: { nodeId, field, fieldTemplate }, - } - : undefined, - [field, fieldName, fieldTemplate, isDraggable, nodeId] + label || fieldTemplateTitle || 'Unknown Field' ); const handleSubmit = useCallback( async (newTitle: string) => { + if (newTitle && (newTitle === label || newTitle === fieldTemplateTitle)) { + return; + } + setLocalTitle(newTitle || fieldTemplateTitle || 'Unknown Field'); dispatch(fieldLabelChanged({ nodeId, fieldName, label: newTitle })); - setLocalTitle(newTitle || fieldTemplate?.title || 'Unknown Field'); }, - [dispatch, nodeId, fieldName, fieldTemplate?.title] + [label, fieldTemplateTitle, dispatch, nodeId, fieldName] ); const handleChange = useCallback((newTitle: string) => { @@ -68,39 +46,54 @@ const FieldTitle = (props: Props) => { useEffect(() => { // Another component may change the title; sync local title with global state - setLocalTitle(field?.label || fieldTemplate?.title || 'Unknown Field'); - }, [field?.label, fieldTemplate?.title]); + setLocalTitle(label || fieldTemplateTitle || 'Unknown Field'); + }, [label, fieldTemplateTitle]); return ( { }, }} /> - + ); -}; +}); export default memo(FieldTitle); -type EditableControlsProps = { - draggableData?: NodeFieldDraggableData; -}; - -const EditableControls = memo((props: EditableControlsProps) => { +const EditableControls = memo(() => { const { isEditing, getEditButtonProps } = useEditableControls(); - const handleDoubleClick = useCallback( + const handleClick = useCallback( (e: MouseEvent) => { const { onClick } = getEditButtonProps(); if (!onClick) { return; } onClick(e); + e.preventDefault(); }, [getEditButtonProps] ); @@ -137,19 +127,9 @@ const EditableControls = memo((props: EditableControlsProps) => { return null; } - if (props.draggableData) { - return ( - - ); - } - return ( { const isInputTemplate = isInputFieldTemplate(fieldTemplate); const fieldTitle = useMemo(() => { if (isInputFieldValue(field)) { - if (field.label && fieldTemplate) { + if (field.label && fieldTemplate?.title) { return `${field.label} (${fieldTemplate.title})`; } @@ -53,4 +51,4 @@ const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => { ); }; -export default FieldTooltipContent; +export default memo(FieldTooltipContent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx new file mode 100644 index 0000000000..3758ae4114 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx @@ -0,0 +1,178 @@ +import { Box, Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react'; +import SelectionOverlay from 'common/components/SelectionOverlay'; +import { useConnectionState } from 'features/nodes/hooks/useConnectionState'; +import { useDoesInputHaveValue } from 'features/nodes/hooks/useDoesInputHaveValue'; +import { useFieldInputKind } from 'features/nodes/hooks/useFieldInputKind'; +import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate'; +import { useIsMouseOverField } from 'features/nodes/hooks/useIsMouseOverField'; +import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants'; +import { PropsWithChildren, memo, useMemo } from 'react'; +import FieldContextMenu from './FieldContextMenu'; +import FieldHandle from './FieldHandle'; +import FieldTitle from './FieldTitle'; +import FieldTooltipContent from './FieldTooltipContent'; +import InputFieldRenderer from './InputFieldRenderer'; + +interface Props { + nodeId: string; + fieldName: string; +} + +const InputField = ({ nodeId, fieldName }: Props) => { + const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input'); + const doesFieldHaveValue = useDoesInputHaveValue(nodeId, fieldName); + const input = useFieldInputKind(nodeId, fieldName); + + const { + isConnected, + isConnectionInProgress, + isConnectionStartField, + connectionError, + shouldDim, + } = useConnectionState({ nodeId, fieldName, kind: 'input' }); + + const isMissingInput = useMemo(() => { + if (fieldTemplate?.fieldKind !== 'input') { + return false; + } + + if (!fieldTemplate.required) { + return false; + } + + if (!isConnected && fieldTemplate.input === 'connection') { + return true; + } + + if (!doesFieldHaveValue && !isConnected && fieldTemplate.input === 'any') { + return true; + } + }, [fieldTemplate, isConnected, doesFieldHaveValue]); + + if (fieldTemplate?.fieldKind !== 'input') { + return ( + + + Unknown input: {fieldName} + + + ); + } + + return ( + + + + {(ref) => ( + + } + openDelay={HANDLE_TOOLTIP_OPEN_DELAY} + placement="top" + hasArrow + > + + + + + )} + + + + + + + {fieldTemplate.input !== 'direct' && ( + + )} + + ); +}; + +export default memo(InputField); + +type InputFieldWrapperProps = PropsWithChildren<{ + shouldDim: boolean; + nodeId: string; + fieldName: string; +}>; + +const InputFieldWrapper = memo( + ({ shouldDim, nodeId, fieldName, children }: InputFieldWrapperProps) => { + const { isMouseOverField, handleMouseOver, handleMouseOut } = + useIsMouseOverField(nodeId, fieldName); + + return ( + + {children} + + + ); + } +); + +InputFieldWrapper.displayName = 'InputFieldWrapper'; diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx similarity index 73% rename from invokeai/frontend/web/src/features/nodes/components/fields/InputFieldRenderer.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx index acec921d8e..9b3ce100c8 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/InputFieldRenderer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx @@ -1,30 +1,29 @@ -import { Box } from '@chakra-ui/react'; -import { - useFieldData, - useFieldTemplate, -} from 'features/nodes/hooks/useNodeData'; +import { Box, Text } from '@chakra-ui/react'; +import { useFieldData } from 'features/nodes/hooks/useFieldData'; +import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate'; import { memo } from 'react'; -import BooleanInputField from './fieldTypes/BooleanInputField'; -import ClipInputField from './fieldTypes/ClipInputField'; -import CollectionInputField from './fieldTypes/CollectionInputField'; -import CollectionItemInputField from './fieldTypes/CollectionItemInputField'; -import ColorInputField from './fieldTypes/ColorInputField'; -import ConditioningInputField from './fieldTypes/ConditioningInputField'; -import ControlInputField from './fieldTypes/ControlInputField'; -import ControlNetModelInputField from './fieldTypes/ControlNetModelInputField'; -import EnumInputField from './fieldTypes/EnumInputField'; -import ImageCollectionInputField from './fieldTypes/ImageCollectionInputField'; -import ImageInputField from './fieldTypes/ImageInputField'; -import LatentsInputField from './fieldTypes/LatentsInputField'; -import LoRAModelInputField from './fieldTypes/LoRAModelInputField'; -import MainModelInputField from './fieldTypes/MainModelInputField'; -import NumberInputField from './fieldTypes/NumberInputField'; -import RefinerModelInputField from './fieldTypes/RefinerModelInputField'; -import SDXLMainModelInputField from './fieldTypes/SDXLMainModelInputField'; -import StringInputField from './fieldTypes/StringInputField'; -import UnetInputField from './fieldTypes/UnetInputField'; -import VaeInputField from './fieldTypes/VaeInputField'; -import VaeModelInputField from './fieldTypes/VaeModelInputField'; +import BooleanInputField from './inputs/BooleanInputField'; +import ClipInputField from './inputs/ClipInputField'; +import CollectionInputField from './inputs/CollectionInputField'; +import CollectionItemInputField from './inputs/CollectionItemInputField'; +import ColorInputField from './inputs/ColorInputField'; +import ConditioningInputField from './inputs/ConditioningInputField'; +import ControlInputField from './inputs/ControlInputField'; +import ControlNetModelInputField from './inputs/ControlNetModelInputField'; +import EnumInputField from './inputs/EnumInputField'; +import ImageCollectionInputField from './inputs/ImageCollectionInputField'; +import ImageInputField from './inputs/ImageInputField'; +import LatentsInputField from './inputs/LatentsInputField'; +import LoRAModelInputField from './inputs/LoRAModelInputField'; +import MainModelInputField from './inputs/MainModelInputField'; +import NumberInputField from './inputs/NumberInputField'; +import RefinerModelInputField from './inputs/RefinerModelInputField'; +import SDXLMainModelInputField from './inputs/SDXLMainModelInputField'; +import SchedulerInputField from './inputs/SchedulerInputField'; +import StringInputField from './inputs/StringInputField'; +import UnetInputField from './inputs/UnetInputField'; +import VaeInputField from './inputs/VaeInputField'; +import VaeModelInputField from './inputs/VaeModelInputField'; type InputFieldProps = { nodeId: string; @@ -286,7 +285,30 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { ); } - return Unknown field type: {field?.type}; + if (field?.type === 'Scheduler' && fieldTemplate?.type === 'Scheduler') { + return ( + + ); + } + + return ( + + + Unknown field type: {field?.type} + + + ); }; export default memo(InputFieldRenderer); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx new file mode 100644 index 0000000000..cbf4a19137 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx @@ -0,0 +1,82 @@ +import { Flex, FormControl, FormLabel, Icon, Tooltip } from '@chakra-ui/react'; +import { useAppDispatch } from 'app/store/storeHooks'; +import IAIIconButton from 'common/components/IAIIconButton'; +import SelectionOverlay from 'common/components/SelectionOverlay'; +import { useIsMouseOverField } from 'features/nodes/hooks/useIsMouseOverField'; +import { workflowExposedFieldRemoved } from 'features/nodes/store/nodesSlice'; +import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants'; +import { memo, useCallback } from 'react'; +import { FaInfoCircle, FaTrash } from 'react-icons/fa'; +import FieldTitle from './FieldTitle'; +import FieldTooltipContent from './FieldTooltipContent'; +import InputFieldRenderer from './InputFieldRenderer'; + +type Props = { + nodeId: string; + fieldName: string; +}; + +const LinearViewField = ({ nodeId, fieldName }: Props) => { + const dispatch = useAppDispatch(); + const { isMouseOverField, handleMouseOut, handleMouseOver } = + useIsMouseOverField(nodeId, fieldName); + + const handleRemoveField = useCallback(() => { + dispatch(workflowExposedFieldRemoved({ nodeId, fieldName })); + }, [dispatch, fieldName, nodeId]); + + return ( + + + + + + } + openDelay={HANDLE_TOOLTIP_OPEN_DELAY} + placement="top" + hasArrow + > + + + + + } + /> + + + + + + ); +}; + +export default memo(LinearViewField); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/OutputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx similarity index 92% rename from invokeai/frontend/web/src/features/nodes/components/fields/OutputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx index 2a257d741e..e717423f65 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/OutputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx @@ -1,12 +1,6 @@ -import { - Flex, - FormControl, - FormLabel, - Spacer, - Tooltip, -} from '@chakra-ui/react'; +import { Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react'; import { useConnectionState } from 'features/nodes/hooks/useConnectionState'; -import { useFieldTemplate } from 'features/nodes/hooks/useNodeData'; +import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate'; import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants'; import { PropsWithChildren, memo } from 'react'; import FieldHandle from './FieldHandle'; @@ -42,7 +36,6 @@ const OutputField = ({ nodeId, fieldName }: Props) => { return ( - {children} diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/BooleanInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BooleanInputField.tsx similarity index 86% rename from invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/BooleanInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BooleanInputField.tsx index daf2f598ba..c9f83403f6 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/BooleanInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BooleanInputField.tsx @@ -4,9 +4,9 @@ import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice'; import { BooleanInputFieldTemplate, BooleanInputFieldValue, + FieldComponentProps, } from 'features/nodes/types/types'; import { ChangeEvent, memo, useCallback } from 'react'; -import { FieldComponentProps } from './types'; const BooleanInputFieldComponent = ( props: FieldComponentProps @@ -29,7 +29,11 @@ const BooleanInputFieldComponent = ( ); return ( - + ); }; diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/ClipInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ClipInputField.tsx similarity index 86% rename from invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/ClipInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ClipInputField.tsx index 37c3db3d11..cf5d7fae95 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/ClipInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ClipInputField.tsx @@ -1,9 +1,9 @@ import { ClipInputFieldTemplate, ClipInputFieldValue, + FieldComponentProps, } from 'features/nodes/types/types'; import { memo } from 'react'; -import { FieldComponentProps } from './types'; const ClipInputFieldComponent = ( _props: FieldComponentProps diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/CollectionInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/CollectionInputField.tsx similarity index 88% rename from invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/CollectionInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/CollectionInputField.tsx index 99c88af2cb..7cbc46f28c 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/CollectionInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/CollectionInputField.tsx @@ -1,9 +1,9 @@ import { CollectionInputFieldTemplate, CollectionInputFieldValue, + FieldComponentProps, } from 'features/nodes/types/types'; import { memo } from 'react'; -import { FieldComponentProps } from './types'; const CollectionInputFieldComponent = ( _props: FieldComponentProps< diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/CollectionItemInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/CollectionItemInputField.tsx similarity index 88% rename from invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/CollectionItemInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/CollectionItemInputField.tsx index 00f753d8d3..e67a20bdfb 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/CollectionItemInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/CollectionItemInputField.tsx @@ -1,9 +1,9 @@ import { CollectionItemInputFieldTemplate, CollectionItemInputFieldValue, + FieldComponentProps, } from 'features/nodes/types/types'; import { memo } from 'react'; -import { FieldComponentProps } from './types'; const CollectionItemInputFieldComponent = ( _props: FieldComponentProps< diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/ColorInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ColorInputField.tsx similarity index 95% rename from invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/ColorInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ColorInputField.tsx index 422c3ba48f..c2af279cb5 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/ColorInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ColorInputField.tsx @@ -3,10 +3,10 @@ import { fieldColorValueChanged } from 'features/nodes/store/nodesSlice'; import { ColorInputFieldTemplate, ColorInputFieldValue, + FieldComponentProps, } from 'features/nodes/types/types'; import { memo, useCallback } from 'react'; import { RgbaColor, RgbaColorPicker } from 'react-colorful'; -import { FieldComponentProps } from './types'; const ColorInputFieldComponent = ( props: FieldComponentProps diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/ConditioningInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ConditioningInputField.tsx similarity index 88% rename from invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/ConditioningInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ConditioningInputField.tsx index e280251cd3..9d174f40c5 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/ConditioningInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ConditioningInputField.tsx @@ -1,9 +1,9 @@ import { ConditioningInputFieldTemplate, ConditioningInputFieldValue, + FieldComponentProps, } from 'features/nodes/types/types'; import { memo } from 'react'; -import { FieldComponentProps } from './types'; const ConditioningInputFieldComponent = ( _props: FieldComponentProps< diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/ControlInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlInputField.tsx similarity index 87% rename from invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/ControlInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlInputField.tsx index 6b2b3deafb..346dd49b21 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/ControlInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlInputField.tsx @@ -1,9 +1,9 @@ import { ControlInputFieldTemplate, ControlInputFieldValue, + FieldComponentProps, } from 'features/nodes/types/types'; import { memo } from 'react'; -import { FieldComponentProps } from './types'; const ControlInputFieldComponent = ( _props: FieldComponentProps diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/ControlNetModelInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelInputField.tsx similarity index 97% rename from invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/ControlNetModelInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelInputField.tsx index 492ec51d20..f66c8b0cfd 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/ControlNetModelInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelInputField.tsx @@ -5,13 +5,13 @@ import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlic import { ControlNetModelInputFieldTemplate, ControlNetModelInputFieldValue, + FieldComponentProps, } from 'features/nodes/types/types'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam'; import { forEach } from 'lodash-es'; import { memo, useCallback, useMemo } from 'react'; import { useGetControlNetModelsQuery } from 'services/api/endpoints/models'; -import { FieldComponentProps } from './types'; const ControlNetModelInputFieldComponent = ( props: FieldComponentProps< @@ -85,7 +85,7 @@ const ControlNetModelInputFieldComponent = ( return ( @@ -30,7 +30,7 @@ const EnumInputFieldComponent = ( return (