diff --git a/.gitignore b/.gitignore index 7f3b1278df..e9918d4fb5 100644 --- a/.gitignore +++ b/.gitignore @@ -201,8 +201,6 @@ checkpoints # If it's a Mac .DS_Store -invokeai/frontend/web/dist/* - # Let the frontend manage its own gitignore !invokeai/frontend/web/* diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 0b03c8e729..dcbdbec82d 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -2,17 +2,17 @@ from typing import Literal, Optional, Union -from fastapi import Query +from fastapi import Query, Body from fastapi.routing import APIRouter, HTTPException from pydantic import BaseModel, Field, parse_obj_as from ..dependencies import ApiDependencies from invokeai.backend import BaseModelType, ModelType +from invokeai.backend.model_management import AddModelResult from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)] models_router = APIRouter(prefix="/v1/models", tags=["models"]) - class VaeRepo(BaseModel): repo_id: str = Field(description="The repo ID to use for this VAE") path: Optional[str] = Field(description="The path to the VAE") @@ -51,9 +51,12 @@ class CreateModelResponse(BaseModel): info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info") status: str = Field(description="The status of the API response") -class ImportModelRequest(BaseModel): - name: str = Field(description="A model path, repo_id or URL to import") - prediction_type: Optional[Literal['epsilon','v_prediction','sample']] = Field(description='Prediction type for SDv2 checkpoint files') +class ImportModelResponse(BaseModel): + name: str = Field(description="The name of the imported model") +# base_model: str = Field(description="The base model") +# model_type: str = Field(description="The model type") + info: AddModelResult = Field(description="The model info") + status: str = Field(description="The status of the API response") class ConversionRequest(BaseModel): name: str = Field(description="The name of the new model") @@ -86,7 +89,6 @@ async def list_models( models = parse_obj_as(ModelsList, { "models": models_raw }) return models - @models_router.post( "/", operation_id="update_model", @@ -109,27 +111,38 @@ async def update_model( return model_response @models_router.post( - "/", + "/import", operation_id="import_model", - responses={200: {"status": "success"}}, + responses= { + 201: {"description" : "The model imported successfully"}, + 404: {"description" : "The model could not be found"}, + }, + status_code=201, + response_model=ImportModelResponse ) async def import_model( - model_request: ImportModelRequest -) -> None: - """ Add Model """ - items_to_import = set([model_request.name]) + name: str = Query(description="A model path, repo_id or URL to import"), + prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = Query(description='Prediction type for SDv2 checkpoint files', default="v_prediction"), +) -> ImportModelResponse: + """ Add a model using its local path, repo_id, or remote URL """ + items_to_import = {name} prediction_types = { x.value: x for x in SchedulerPredictionType } logger = ApiDependencies.invoker.services.logger installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import( items_to_import = items_to_import, - prediction_type_helper = lambda x: prediction_types.get(model_request.prediction_type) + prediction_type_helper = lambda x: prediction_types.get(prediction_type) ) - if len(installed_models) > 0: - logger.info(f'Successfully imported {model_request.name}') + if info := installed_models.get(name): + logger.info(f'Successfully imported {name}, got {info}') + return ImportModelResponse( + name = name, + info = info, + status = "success", + ) else: - logger.error(f'Model {model_request.name} not imported') - raise HTTPException(status_code=500, detail=f'Model {model_request.name} not imported') + logger.error(f'Model {name} not imported') + raise HTTPException(status_code=404, detail=f'Model {name} not found') @models_router.delete( "/{model_name}", diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 4ce3e839b6..4c7314bd2b 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -4,9 +4,10 @@ from __future__ import annotations from abc import ABC, abstractmethod from inspect import signature -from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict, TYPE_CHECKING +from typing import (TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args, + get_type_hints) -from pydantic import BaseModel, Field +from pydantic import BaseConfig, BaseModel, Field if TYPE_CHECKING: from ..services.invocation_services import InvocationServices @@ -65,8 +66,13 @@ class BaseInvocation(ABC, BaseModel): @classmethod def get_invocations_map(cls): # Get the type strings out of the literals and into a dictionary - return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseInvocation.get_all_subclasses())) - + return dict( + map( + lambda t: (get_args(get_type_hints(t)["type"])[0], t), + BaseInvocation.get_all_subclasses(), + ) + ) + @classmethod def get_output_type(cls): return signature(cls.invoke).return_annotation @@ -75,11 +81,11 @@ class BaseInvocation(ABC, BaseModel): def invoke(self, context: InvocationContext) -> BaseInvocationOutput: """Invoke with provided context and return outputs.""" pass - - #fmt: off + + # fmt: off id: str = Field(description="The id of this node. Must be unique among all nodes.") is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.") - #fmt: on + # fmt: on # TODO: figure out a better way to provide these hints @@ -97,16 +103,20 @@ class UIConfig(TypedDict, total=False): "latents", "model", "control", + "image_collection", + "vae_model", + "lora_model", ], ] tags: List[str] title: str + class CustomisedSchemaExtra(TypedDict): ui: UIConfig -class InvocationConfig(BaseModel.Config): +class InvocationConfig(BaseConfig): """Customizes pydantic's BaseModel.Config class for use by Invocations. Provide `schema_extra` a `ui` dict to add hints for generated UIs. diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py index 891f217317..33bde42d69 100644 --- a/invokeai/app/invocations/collections.py +++ b/invokeai/app/invocations/collections.py @@ -4,13 +4,16 @@ from typing import Literal import numpy as np from pydantic import Field, validator +from invokeai.app.models.image import ImageField from invokeai.app.util.misc import SEED_MAX, get_random_seed from .baseinvocation import ( BaseInvocation, + InvocationConfig, InvocationContext, BaseInvocationOutput, + UIConfig, ) @@ -22,6 +25,7 @@ class IntCollectionOutput(BaseInvocationOutput): # Outputs collection: list[int] = Field(default=[], description="The int collection") + class FloatCollectionOutput(BaseInvocationOutput): """A collection of floats""" @@ -31,6 +35,18 @@ class FloatCollectionOutput(BaseInvocationOutput): collection: list[float] = Field(default=[], description="The float collection") +class ImageCollectionOutput(BaseInvocationOutput): + """A collection of images""" + + type: Literal["image_collection"] = "image_collection" + + # Outputs + collection: list[ImageField] = Field(default=[], description="The output images") + + class Config: + schema_extra = {"required": ["type", "collection"]} + + class RangeInvocation(BaseInvocation): """Creates a range of numbers from start to stop with step""" @@ -92,3 +108,27 @@ class RandomRangeInvocation(BaseInvocation): return IntCollectionOutput( collection=list(rng.integers(low=self.low, high=self.high, size=self.size)) ) + + +class ImageCollectionInvocation(BaseInvocation): + """Load a collection of images and provide it as output.""" + + # fmt: off + type: Literal["image_collection"] = "image_collection" + + # Inputs + images: list[ImageField] = Field( + default=[], description="The image collection to load" + ) + # fmt: on + def invoke(self, context: InvocationContext) -> ImageCollectionOutput: + return ImageCollectionOutput(collection=self.images) + + class Config(InvocationConfig): + schema_extra = { + "ui": { + "type_hints": { + "images": "image_collection", + } + }, + } diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 8c6b23944c..4850b9670d 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -1,27 +1,28 @@ -from typing import Literal, Optional, Union -from pydantic import BaseModel, Field -from contextlib import ExitStack import re +from contextlib import ExitStack +from typing import List, Literal, Optional, Union -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig -from .model import ClipField +import torch +from compel import Compel +from compel.prompt_parser import (Blend, Conjunction, + CrossAttentionControlSubstitute, + FlattenedPrompt, Fragment) +from pydantic import BaseModel, Field -from ...backend.util.devices import torch_dtype -from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent +from ...backend.model_management.models import ModelNotFoundException from ...backend.model_management import BaseModelType, ModelType, SubModelType from ...backend.model_management.lora import ModelPatcher - -from compel import Compel -from compel.prompt_parser import ( - Blend, - CrossAttentionControlSubstitute, - FlattenedPrompt, - Fragment, Conjunction, -) +from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent +from ...backend.util.devices import torch_dtype +from .baseinvocation import (BaseInvocation, BaseInvocationOutput, + InvocationConfig, InvocationContext) +from .model import ClipField class ConditioningField(BaseModel): - conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data") + conditioning_name: Optional[str] = Field( + default=None, description="The name of conditioning data") + class Config: schema_extra = {"required": ["conditioning_name"]} @@ -51,83 +52,92 @@ class CompelInvocation(BaseInvocation): "title": "Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": { - "model": "model" + "model": "model" } }, } + @torch.no_grad() def invoke(self, context: InvocationContext) -> CompelOutput: - tokenizer_info = context.services.model_manager.get_model( **self.clip.tokenizer.dict(), ) text_encoder_info = context.services.model_manager.get_model( **self.clip.text_encoder.dict(), ) - with tokenizer_info as orig_tokenizer,\ - text_encoder_info as text_encoder: - loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] + def _lora_loader(): + for lora in self.clip.loras: + lora_info = context.services.model_manager.get_model( + **lora.dict(exclude={"weight"})) + yield (lora_info.context.model, lora.weight) + del lora_info + return - ti_list = [] - for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): - name = trigger[1:-1] - try: - ti_list.append( - context.services.model_manager.get_model( - model_name=name, - base_model=self.clip.text_encoder.base_model, - model_type=ModelType.TextualInversion, - ).context.model - ) - except Exception: - #print(e) - #import traceback - #print(traceback.format_exc()) - print(f"Warn: trigger: \"{trigger}\" not found") + #loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] - with ModelPatcher.apply_lora_text_encoder(text_encoder, loras),\ - ModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager): - - compel = Compel( - tokenizer=tokenizer, - text_encoder=text_encoder, - textual_inversion_manager=ti_manager, - dtype_for_device_getter=torch_dtype, - truncate_long_prompts=True, # TODO: + ti_list = [] + for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): + name = trigger[1:-1] + try: + ti_list.append( + context.services.model_manager.get_model( + model_name=name, + base_model=self.clip.text_encoder.base_model, + model_type=ModelType.TextualInversion, + ).context.model ) - - conjunction = Compel.parse_prompt_string(self.prompt) - prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0] + except ModelNotFoundException: + # print(e) + #import traceback + #print(traceback.format_exc()) + print(f"Warn: trigger: \"{trigger}\" not found") - if context.services.configuration.log_tokenization: - log_tokenization_for_prompt_object(prompt, tokenizer) + with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\ + ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\ + text_encoder_info as text_encoder: - c, options = compel.build_conditioning_tensor_for_prompt_object(prompt) - - # TODO: long prompt support - #if not self.truncate_long_prompts: - # [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc]) - ec = InvokeAIDiffuserComponent.ExtraConditioningInfo( - tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction), - cross_attention_control_args=options.get("cross_attention_control", None), - ) - - conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - - # TODO: hacky but works ;D maybe rename latents somehow? - context.services.latents.save(conditioning_name, (c, ec)) - - return CompelOutput( - conditioning=ConditioningField( - conditioning_name=conditioning_name, - ), + compel = Compel( + tokenizer=tokenizer, + text_encoder=text_encoder, + textual_inversion_manager=ti_manager, + dtype_for_device_getter=torch_dtype, + truncate_long_prompts=True, # TODO: ) + conjunction = Compel.parse_prompt_string(self.prompt) + prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0] + + if context.services.configuration.log_tokenization: + log_tokenization_for_prompt_object(prompt, tokenizer) + + c, options = compel.build_conditioning_tensor_for_prompt_object( + prompt) + + # TODO: long prompt support + # if not self.truncate_long_prompts: + # [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc]) + ec = InvokeAIDiffuserComponent.ExtraConditioningInfo( + tokens_count_including_eos_bos=get_max_token_count( + tokenizer, conjunction), + cross_attention_control_args=options.get( + "cross_attention_control", None),) + + conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" + + # TODO: hacky but works ;D maybe rename latents somehow? + context.services.latents.save(conditioning_name, (c, ec)) + + return CompelOutput( + conditioning=ConditioningField( + conditioning_name=conditioning_name, + ), + ) + def get_max_token_count( - tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False -) -> int: + tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], + truncate_if_too_long=False) -> int: if type(prompt) is Blend: blend: Blend = prompt return max( @@ -146,13 +156,13 @@ def get_max_token_count( ) else: return len( - get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long) - ) + get_tokens_for_prompt_object( + tokenizer, prompt, truncate_if_too_long)) def get_tokens_for_prompt_object( tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True -) -> [str]: +) -> List[str]: if type(parsed_prompt) is Blend: raise ValueError( "Blend is not supported here - you need to get tokens for each of its .children" @@ -181,7 +191,7 @@ def log_tokenization_for_conjunction( ): display_label_prefix = display_label_prefix or "" for i, p in enumerate(c.prompts): - if len(c.prompts)>1: + if len(c.prompts) > 1: this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})" else: this_display_label_prefix = display_label_prefix @@ -236,7 +246,8 @@ def log_tokenization_for_prompt_object( ) -def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False): +def log_tokenization_for_text( + text, tokenizer, display_label=None, truncate_if_too_long=False): """shows how the prompt is tokenized # usually tokens have '' to indicate end-of-word, # but for readability it has been replaced with ' ' diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index a9576a2fe1..3e691c934e 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -4,18 +4,17 @@ from contextlib import ExitStack from typing import List, Literal, Optional, Union import einops - -from pydantic import BaseModel, Field, validator import torch from diffusers import ControlNetModel, DPMSolverMultistepScheduler from diffusers.image_processor import VaeImageProcessor from diffusers.schedulers import SchedulerMixin as Scheduler +from pydantic import BaseModel, Field, validator from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.step_callback import stable_diffusion_step_callback -from ..models.image import ImageCategory, ImageField, ResourceOrigin from ...backend.image_util.seamless import configure_model_padding +from ...backend.model_management.lora import ModelPatcher from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion.diffusers_pipeline import ( ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline, @@ -24,7 +23,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \ PostprocessingSettings from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.util.devices import torch_dtype -from ...backend.model_management.lora import ModelPatcher +from ..models.image import ImageCategory, ImageField, ResourceOrigin from .baseinvocation import (BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext) from .compel import ConditioningField @@ -32,14 +31,17 @@ from .controlnet_image_processors import ControlField from .image import ImageOutput from .model import ModelInfo, UNetField, VaeField + class LatentsField(BaseModel): """A latents field used for passing latents between invocations""" - latents_name: Optional[str] = Field(default=None, description="The name of the latents") + latents_name: Optional[str] = Field( + default=None, description="The name of the latents") class Config: schema_extra = {"required": ["latents_name"]} + class LatentsOutput(BaseInvocationOutput): """Base class for invocations that output latents""" #fmt: off @@ -53,11 +55,11 @@ class LatentsOutput(BaseInvocationOutput): def build_latents_output(latents_name: str, latents: torch.Tensor): - return LatentsOutput( - latents=LatentsField(latents_name=latents_name), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, - ) + return LatentsOutput( + latents=LatentsField(latents_name=latents_name), + width=latents.size()[3] * 8, + height=latents.size()[2] * 8, + ) SAMPLER_NAME_VALUES = Literal[ @@ -70,16 +72,19 @@ def get_scheduler( scheduler_info: ModelInfo, scheduler_name: str, ) -> Scheduler: - scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim']) - orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict()) + scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get( + scheduler_name, SCHEDULER_MAP['ddim']) + orig_scheduler_info = context.services.model_manager.get_model( + **scheduler_info.dict()) with orig_scheduler_info as orig_scheduler: scheduler_config = orig_scheduler.config - + if "_backup" in scheduler_config: scheduler_config = scheduler_config["_backup"] - scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config} + scheduler_config = {**scheduler_config, ** + scheduler_extra_config, "_backup": scheduler_config} scheduler = scheduler_class.from_config(scheduler_config) - + # hack copied over from generate.py if not hasattr(scheduler, 'uses_inpainting_model'): scheduler.uses_inpainting_model = lambda: False @@ -124,18 +129,18 @@ class TextToLatentsInvocation(BaseInvocation): "ui": { "tags": ["latents"], "type_hints": { - "model": "model", - "control": "control", - # "cfg_scale": "float", - "cfg_scale": "number" + "model": "model", + "control": "control", + # "cfg_scale": "float", + "cfg_scale": "number" } }, } # TODO: pass this an emitter method or something? or a session for dispatching? def dispatch_progress( - self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState - ) -> None: + self, context: InvocationContext, source_node_id: str, + intermediate_state: PipelineIntermediateState) -> None: stable_diffusion_step_callback( context=context, intermediate_state=intermediate_state, @@ -143,9 +148,12 @@ class TextToLatentsInvocation(BaseInvocation): source_node_id=source_node_id, ) - def get_conditioning_data(self, context: InvocationContext, scheduler) -> ConditioningData: - c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name) - uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name) + def get_conditioning_data( + self, context: InvocationContext, scheduler) -> ConditioningData: + c, extra_conditioning_info = context.services.latents.get( + self.positive_conditioning.conditioning_name) + uc, _ = context.services.latents.get( + self.negative_conditioning.conditioning_name) conditioning_data = ConditioningData( unconditioned_embeddings=uc, @@ -153,10 +161,10 @@ class TextToLatentsInvocation(BaseInvocation): guidance_scale=self.cfg_scale, extra=extra_conditioning_info, postprocessing_settings=PostprocessingSettings( - threshold=0.0,#threshold, - warmup=0.2,#warmup, - h_symmetry_time_pct=None,#h_symmetry_time_pct, - v_symmetry_time_pct=None#v_symmetry_time_pct, + threshold=0.0, # threshold, + warmup=0.2, # warmup, + h_symmetry_time_pct=None, # h_symmetry_time_pct, + v_symmetry_time_pct=None # v_symmetry_time_pct, ), ) @@ -164,31 +172,32 @@ class TextToLatentsInvocation(BaseInvocation): scheduler, # for ddim scheduler - eta=0.0, #ddim_eta + eta=0.0, # ddim_eta # for ancestral and sde schedulers generator=torch.Generator(device=uc.device).manual_seed(0), ) return conditioning_data - def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline: + def create_pipeline( + self, unet, scheduler) -> StableDiffusionGeneratorPipeline: # TODO: - #configure_model_padding( + # configure_model_padding( # unet, # self.seamless, # self.seamless_axes, - #) + # ) class FakeVae: class FakeVaeConfig: def __init__(self): self.block_out_channels = [0] - + def __init__(self): self.config = FakeVae.FakeVaeConfig() return StableDiffusionGeneratorPipeline( - vae=FakeVae(), # TODO: oh... + vae=FakeVae(), # TODO: oh... text_encoder=None, tokenizer=None, unet=unet, @@ -198,11 +207,12 @@ class TextToLatentsInvocation(BaseInvocation): requires_safety_checker=False, precision="float16" if unet.dtype == torch.float16 else "float32", ) - + def prep_control_data( self, context: InvocationContext, - model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device + # really only need model for dtype and device + model: StableDiffusionGeneratorPipeline, control_input: List[ControlField], latents_shape: List[int], do_classifier_free_guidance: bool = True, @@ -238,15 +248,17 @@ class TextToLatentsInvocation(BaseInvocation): print("Using HF model subfolders") print(" control_name: ", control_name) print(" control_subfolder: ", control_subfolder) - control_model = ControlNetModel.from_pretrained(control_name, - subfolder=control_subfolder, - torch_dtype=model.unet.dtype).to(model.device) + control_model = ControlNetModel.from_pretrained( + control_name, subfolder=control_subfolder, + torch_dtype=model.unet.dtype).to( + model.device) else: - control_model = ControlNetModel.from_pretrained(control_info.control_model, - torch_dtype=model.unet.dtype).to(model.device) + control_model = ControlNetModel.from_pretrained( + control_info.control_model, torch_dtype=model.unet.dtype).to(model.device) control_models.append(control_model) control_image_field = control_info.image - input_image = context.services.images.get_pil_image(control_image_field.image_name) + input_image = context.services.images.get_pil_image( + control_image_field.image_name) # self.image.image_type, self.image.image_name # FIXME: still need to test with different widths, heights, devices, dtypes # and add in batch_size, num_images_per_prompt? @@ -263,41 +275,50 @@ class TextToLatentsInvocation(BaseInvocation): dtype=control_model.dtype, control_mode=control_info.control_mode, ) - control_item = ControlNetData(model=control_model, - image_tensor=control_image, - weight=control_info.control_weight, - begin_step_percent=control_info.begin_step_percent, - end_step_percent=control_info.end_step_percent, - control_mode=control_info.control_mode, - ) + control_item = ControlNetData( + model=control_model, image_tensor=control_image, + weight=control_info.control_weight, + begin_step_percent=control_info.begin_step_percent, + end_step_percent=control_info.end_step_percent, + control_mode=control_info.control_mode,) control_data.append(control_item) # MultiControlNetModel has been refactored out, just need list[ControlNetData] return control_data + @torch.no_grad() def invoke(self, context: InvocationContext) -> LatentsOutput: noise = context.services.latents.get(self.noise.latents_name) # Get the source node id (we are invoking the prepared node) - graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) + graph_execution_state = context.services.graph_execution_manager.get( + context.graph_execution_state_id) source_node_id = graph_execution_state.prepared_source_mapping[self.id] def step_callback(state: PipelineIntermediateState): self.dispatch_progress(context, source_node_id, state) - unet_info = context.services.model_manager.get_model(**self.unet.unet.dict()) - with unet_info as unet: + def _lora_loader(): + for lora in self.unet.loras: + lora_info = context.services.model_manager.get_model( + **lora.dict(exclude={"weight"})) + yield (lora_info.context.model, lora.weight) + del lora_info + return + + unet_info = context.services.model_manager.get_model( + **self.unet.unet.dict()) + with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ + unet_info as unet: scheduler = get_scheduler( context=context, scheduler_info=self.unet.scheduler, scheduler_name=self.scheduler, ) - + pipeline = self.create_pipeline(unet, scheduler) conditioning_data = self.get_conditioning_data(context, scheduler) - loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras] - control_data = self.prep_control_data( model=pipeline, context=context, control_input=self.control, latents_shape=noise.shape, @@ -305,16 +326,15 @@ class TextToLatentsInvocation(BaseInvocation): do_classifier_free_guidance=True, ) - with ModelPatcher.apply_lora_unet(pipeline.unet, loras): - # TODO: Verify the noise is the right size - result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( - latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)), - noise=noise, - num_inference_steps=self.steps, - conditioning_data=conditioning_data, - control_data=control_data, # list[ControlNetData] - callback=step_callback, - ) + # TODO: Verify the noise is the right size + result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( + latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)), + noise=noise, + num_inference_steps=self.steps, + conditioning_data=conditioning_data, + control_data=control_data, # list[ControlNetData] + callback=step_callback, + ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() @@ -323,14 +343,18 @@ class TextToLatentsInvocation(BaseInvocation): context.services.latents.save(name, result_latents) return build_latents_output(latents_name=name, latents=result_latents) + class LatentsToLatentsInvocation(TextToLatentsInvocation): """Generates latents using latents as base image.""" type: Literal["l2l"] = "l2l" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to use as a base image") - strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use") + latents: Optional[LatentsField] = Field( + description="The latents to use as a base image") + strength: float = Field( + default=0.7, ge=0, le=1, + description="The strength of the latents to use") # Schema customisation class Config(InvocationConfig): @@ -345,22 +369,31 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): }, } + @torch.no_grad() def invoke(self, context: InvocationContext) -> LatentsOutput: noise = context.services.latents.get(self.noise.latents_name) latent = context.services.latents.get(self.latents.latents_name) # Get the source node id (we are invoking the prepared node) - graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) + graph_execution_state = context.services.graph_execution_manager.get( + context.graph_execution_state_id) source_node_id = graph_execution_state.prepared_source_mapping[self.id] def step_callback(state: PipelineIntermediateState): self.dispatch_progress(context, source_node_id, state) - unet_info = context.services.model_manager.get_model( - **self.unet.unet.dict(), - ) + def _lora_loader(): + for lora in self.unet.loras: + lora_info = context.services.model_manager.get_model( + **lora.dict(exclude={"weight"})) + yield (lora_info.context.model, lora.weight) + del lora_info + return - with unet_info as unet: + unet_info = context.services.model_manager.get_model( + **self.unet.unet.dict()) + with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ + unet_info as unet: scheduler = get_scheduler( context=context, @@ -370,7 +403,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): pipeline = self.create_pipeline(unet, scheduler) conditioning_data = self.get_conditioning_data(context, scheduler) - + control_data = self.prep_control_data( model=pipeline, context=context, control_input=self.control, latents_shape=noise.shape, @@ -380,8 +413,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): # TODO: Verify the noise is the right size initial_latents = latent if self.strength < 1.0 else torch.zeros_like( - latent, device=unet.device, dtype=latent.dtype - ) + latent, device=unet.device, dtype=latent.dtype) timesteps, _ = pipeline.get_img2img_timesteps( self.steps, @@ -389,18 +421,15 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): device=unet.device, ) - loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras] - - with ModelPatcher.apply_lora_unet(pipeline.unet, loras): - result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( - latents=initial_latents, - timesteps=timesteps, - noise=noise, - num_inference_steps=self.steps, - conditioning_data=conditioning_data, - control_data=control_data, # list[ControlNetData] - callback=step_callback - ) + result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( + latents=initial_latents, + timesteps=timesteps, + noise=noise, + num_inference_steps=self.steps, + conditioning_data=conditioning_data, + control_data=control_data, # list[ControlNetData] + callback=step_callback + ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() @@ -417,9 +446,12 @@ class LatentsToImageInvocation(BaseInvocation): type: Literal["l2i"] = "l2i" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to generate an image from") + latents: Optional[LatentsField] = Field( + description="The latents to generate an image from") vae: VaeField = Field(default=None, description="Vae submodel") - tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)") + tiled: bool = Field( + default=False, + description="Decode latents by overlaping tiles(less memory consumption)") # Schema customisation class Config(InvocationConfig): @@ -450,7 +482,7 @@ class LatentsToImageInvocation(BaseInvocation): # copied from diffusers pipeline latents = latents / vae.config.scaling_factor image = vae.decode(latents, return_dict=False)[0] - image = (image / 2 + 0.5).clamp(0, 1) # denormalize + image = (image / 2 + 0.5).clamp(0, 1) # denormalize # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 np_image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -473,9 +505,9 @@ class LatentsToImageInvocation(BaseInvocation): height=image_dto.height, ) -LATENTS_INTERPOLATION_MODE = Literal[ - "nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact" -] + +LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", + "bilinear", "bicubic", "trilinear", "area", "nearest-exact"] class ResizeLatentsInvocation(BaseInvocation): @@ -484,21 +516,25 @@ class ResizeLatentsInvocation(BaseInvocation): type: Literal["lresize"] = "lresize" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to resize") - width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)") - height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)") - mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode") - antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") + latents: Optional[LatentsField] = Field( + description="The latents to resize") + width: int = Field( + ge=64, multiple_of=8, description="The width to resize to (px)") + height: int = Field( + ge=64, multiple_of=8, description="The height to resize to (px)") + mode: LATENTS_INTERPOLATION_MODE = Field( + default="bilinear", description="The interpolation mode") + antialias: bool = Field( + default=False, + description="Whether or not to antialias (applied in bilinear and bicubic modes only)") def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.services.latents.get(self.latents.latents_name) resized_latents = torch.nn.functional.interpolate( - latents, - size=(self.height // 8, self.width // 8), - mode=self.mode, - antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False, - ) + latents, size=(self.height // 8, self.width // 8), + mode=self.mode, antialias=self.antialias + if self.mode in ["bilinear", "bicubic"] else False,) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() @@ -515,21 +551,24 @@ class ScaleLatentsInvocation(BaseInvocation): type: Literal["lscale"] = "lscale" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to scale") - scale_factor: float = Field(gt=0, description="The factor by which to scale the latents") - mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode") - antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") + latents: Optional[LatentsField] = Field( + description="The latents to scale") + scale_factor: float = Field( + gt=0, description="The factor by which to scale the latents") + mode: LATENTS_INTERPOLATION_MODE = Field( + default="bilinear", description="The interpolation mode") + antialias: bool = Field( + default=False, + description="Whether or not to antialias (applied in bilinear and bicubic modes only)") def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.services.latents.get(self.latents.latents_name) # resizing resized_latents = torch.nn.functional.interpolate( - latents, - scale_factor=self.scale_factor, - mode=self.mode, - antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False, - ) + latents, scale_factor=self.scale_factor, mode=self.mode, + antialias=self.antialias + if self.mode in ["bilinear", "bicubic"] else False,) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() @@ -548,7 +587,9 @@ class ImageToLatentsInvocation(BaseInvocation): # Inputs image: Union[ImageField, None] = Field(description="The image to encode") vae: VaeField = Field(default=None, description="Vae submodel") - tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)") + tiled: bool = Field( + default=False, + description="Encode latents by overlaping tiles(less memory consumption)") # Schema customisation class Config(InvocationConfig): diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 760fa08a12..17297ba417 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -1,31 +1,38 @@ -from typing import Literal, Optional, Union, List -from pydantic import BaseModel, Field import copy +from typing import List, Literal, Optional, Union -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig +from pydantic import BaseModel, Field -from ...backend.util.devices import choose_torch_device, torch_dtype from ...backend.model_management import BaseModelType, ModelType, SubModelType +from .baseinvocation import (BaseInvocation, BaseInvocationOutput, + InvocationConfig, InvocationContext) + class ModelInfo(BaseModel): model_name: str = Field(description="Info to load submodel") base_model: BaseModelType = Field(description="Base model") model_type: ModelType = Field(description="Info to load submodel") - submodel: Optional[SubModelType] = Field(description="Info to load submodel") + submodel: Optional[SubModelType] = Field( + default=None, description="Info to load submodel" + ) + class LoraInfo(ModelInfo): weight: float = Field(description="Lora's weight which to use when apply to model") + class UNetField(BaseModel): unet: ModelInfo = Field(description="Info to load unet submodel") scheduler: ModelInfo = Field(description="Info to load scheduler submodel") loras: List[LoraInfo] = Field(description="Loras to apply on model loading") + class ClipField(BaseModel): tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel") text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel") loras: List[LoraInfo] = Field(description="Loras to apply on model loading") + class VaeField(BaseModel): # TODO: better naming? vae: ModelInfo = Field(description="Info to load vae submodel") @@ -34,43 +41,48 @@ class VaeField(BaseModel): class ModelLoaderOutput(BaseInvocationOutput): """Model loader output""" - #fmt: off + # fmt: off type: Literal["model_loader_output"] = "model_loader_output" unet: UNetField = Field(default=None, description="UNet submodel") clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") vae: VaeField = Field(default=None, description="Vae submodel") - #fmt: on + # fmt: on -class PipelineModelField(BaseModel): - """Pipeline model field""" +class MainModelField(BaseModel): + """Main model field""" model_name: str = Field(description="Name of the model") base_model: BaseModelType = Field(description="Base model") -class PipelineModelLoaderInvocation(BaseInvocation): - """Loads a pipeline model, outputting its submodels.""" +class LoRAModelField(BaseModel): + """LoRA model field""" - type: Literal["pipeline_model_loader"] = "pipeline_model_loader" + model_name: str = Field(description="Name of the LoRA model") + base_model: BaseModelType = Field(description="Base model") - model: PipelineModelField = Field(description="The model to load") + +class MainModelLoaderInvocation(BaseInvocation): + """Loads a main model, outputting its submodels.""" + + type: Literal["main_model_loader"] = "main_model_loader" + + model: MainModelField = Field(description="The model to load") # TODO: precision? # Schema customisation class Config(InvocationConfig): schema_extra = { "ui": { + "title": "Model Loader", "tags": ["model", "loader"], - "type_hints": { - "model": "model" - } + "type_hints": {"model": "model"}, }, } def invoke(self, context: InvocationContext) -> ModelLoaderOutput: - base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.Main @@ -112,7 +124,6 @@ class PipelineModelLoaderInvocation(BaseInvocation): ) """ - return ModelLoaderOutput( unet=UNetField( unet=ModelInfo( @@ -151,47 +162,66 @@ class PipelineModelLoaderInvocation(BaseInvocation): model_type=model_type, submodel=SubModelType.Vae, ), - ) + ), ) + class LoraLoaderOutput(BaseInvocationOutput): """Model loader output""" - #fmt: off + # fmt: off type: Literal["lora_loader_output"] = "lora_loader_output" unet: Optional[UNetField] = Field(default=None, description="UNet submodel") clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels") - #fmt: on + # fmt: on + class LoraLoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" type: Literal["lora_loader"] = "lora_loader" - lora_name: str = Field(description="Lora model name") + lora: Union[LoRAModelField, None] = Field( + default=None, description="Lora model name" + ) weight: float = Field(default=0.75, description="With what weight to apply lora") unet: Optional[UNetField] = Field(description="UNet model for applying lora") clip: Optional[ClipField] = Field(description="Clip model for applying lora") - def invoke(self, context: InvocationContext) -> LoraLoaderOutput: + class Config(InvocationConfig): + schema_extra = { + "ui": { + "title": "Lora Loader", + "tags": ["lora", "loader"], + "type_hints": {"lora": "lora_model"}, + }, + } - # TODO: ui rewrite - base_model = BaseModelType.StableDiffusion1 + def invoke(self, context: InvocationContext) -> LoraLoaderOutput: + if self.lora is None: + raise Exception("No LoRA provided") + + base_model = self.lora.base_model + lora_name = self.lora.model_name if not context.services.model_manager.model_exists( base_model=base_model, - model_name=self.lora_name, + model_name=lora_name, model_type=ModelType.Lora, ): - raise Exception(f"Unkown lora name: {self.lora_name}!") + raise Exception(f"Unkown lora name: {lora_name}!") - if self.unet is not None and any(lora.model_name == self.lora_name for lora in self.unet.loras): - raise Exception(f"Lora \"{self.lora_name}\" already applied to unet") + if self.unet is not None and any( + lora.model_name == lora_name for lora in self.unet.loras + ): + raise Exception(f'Lora "{lora_name}" already applied to unet') - if self.clip is not None and any(lora.model_name == self.lora_name for lora in self.clip.loras): - raise Exception(f"Lora \"{self.lora_name}\" already applied to clip") + if self.clip is not None and any( + lora.model_name == lora_name for lora in self.clip.loras + ): + raise Exception(f'Lora "{lora_name}" already applied to clip') output = LoraLoaderOutput() @@ -200,7 +230,7 @@ class LoraLoaderInvocation(BaseInvocation): output.unet.loras.append( LoraInfo( base_model=base_model, - model_name=self.lora_name, + model_name=lora_name, model_type=ModelType.Lora, submodel=None, weight=self.weight, @@ -212,7 +242,7 @@ class LoraLoaderInvocation(BaseInvocation): output.clip.loras.append( LoraInfo( base_model=base_model, - model_name=self.lora_name, + model_name=lora_name, model_type=ModelType.Lora, submodel=None, weight=self.weight, @@ -221,3 +251,58 @@ class LoraLoaderInvocation(BaseInvocation): return output + +class VAEModelField(BaseModel): + """Vae model field""" + + model_name: str = Field(description="Name of the model") + base_model: BaseModelType = Field(description="Base model") + + +class VaeLoaderOutput(BaseInvocationOutput): + """Model loader output""" + + # fmt: off + type: Literal["vae_loader_output"] = "vae_loader_output" + + vae: VaeField = Field(default=None, description="Vae model") + # fmt: on + + +class VaeLoaderInvocation(BaseInvocation): + """Loads a VAE model, outputting a VaeLoaderOutput""" + + type: Literal["vae_loader"] = "vae_loader" + + vae_model: VAEModelField = Field(description="The VAE to load") + + # Schema customisation + class Config(InvocationConfig): + schema_extra = { + "ui": { + "title": "VAE Loader", + "tags": ["vae", "loader"], + "type_hints": {"vae_model": "vae_model"}, + }, + } + + def invoke(self, context: InvocationContext) -> VaeLoaderOutput: + base_model = self.vae_model.base_model + model_name = self.vae_model.model_name + model_type = ModelType.Vae + + if not context.services.model_manager.model_exists( + base_model=base_model, + model_name=model_name, + model_type=model_type, + ): + raise Exception(f"Unkown vae name: {model_name}!") + return VaeLoaderOutput( + vae=VaeField( + vae=ModelInfo( + model_name=model_name, + base_model=base_model, + model_type=model_type, + ) + ) + ) diff --git a/invokeai/app/services/config.py b/invokeai/app/services/config.py index e0f1ceeb25..e7f817fc0a 100644 --- a/invokeai/app/services/config.py +++ b/invokeai/app/services/config.py @@ -228,10 +228,10 @@ class InvokeAISettings(BaseSettings): upcase_environ = dict() for key,value in os.environ.items(): upcase_environ[key.upper()] = value - + fields = cls.__fields__ cls.argparse_groups = {} - + for name, field in fields.items(): if name not in cls._excluded(): current_default = field.default @@ -348,7 +348,7 @@ setting environment variables INVOKEAI_. ''' singleton_config: ClassVar[InvokeAIAppConfig] = None singleton_init: ClassVar[Dict] = None - + #fmt: off type: Literal["InvokeAI"] = "InvokeAI" host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server') @@ -367,7 +367,8 @@ setting environment variables INVOKEAI_. always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance') free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance') - max_loaded_models : int = Field(default=3, gt=0, description="Maximum number of models to keep in memory for rapid switching", category='Memory/Performance') + max_loaded_models : int = Field(default=3, gt=0, description="(DEPRECATED: use max_cache_size) Maximum number of models to keep in memory for rapid switching", category='Memory/Performance') + max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance') precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance') sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance') xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance') @@ -385,9 +386,9 @@ setting environment variables INVOKEAI_. outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths') from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths') use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths') - + model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models') - + log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=", "syslog=path|address:host:port", "http="', category="Logging") # note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues log_format : Literal[tuple(['plain','color','syslog','legacy'])] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging") @@ -396,7 +397,7 @@ setting environment variables INVOKEAI_. def parse_args(self, argv: List[str]=None, conf: DictConfig = None, clobber=False): ''' - Update settings with contents of init file, environment, and + Update settings with contents of init file, environment, and command-line settings. :param conf: alternate Omegaconf dictionary object :param argv: aternate sys.argv list @@ -411,7 +412,7 @@ setting environment variables INVOKEAI_. except: pass InvokeAISettings.initconf = conf - + # parse args again in order to pick up settings in configuration file super().parse_args(argv) @@ -431,7 +432,7 @@ setting environment variables INVOKEAI_. cls.singleton_config = cls(**kwargs) cls.singleton_init = kwargs return cls.singleton_config - + @property def root_path(self)->Path: ''' diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 10d1d91920..4e1da3b040 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -7,7 +7,7 @@ if TYPE_CHECKING: from invokeai.app.services.board_images import BoardImagesServiceABC from invokeai.app.services.boards import BoardServiceABC from invokeai.app.services.images import ImageServiceABC - from invokeai.backend import ModelManager + from invokeai.app.services.model_manager_service import ModelManagerServiceBase from invokeai.app.services.events import EventServiceBase from invokeai.app.services.latent_storage import LatentsStorageBase from invokeai.app.services.restoration_services import RestorationServices @@ -22,46 +22,47 @@ class InvocationServices: """Services that can be used by invocations""" # TODO: Just forward-declared everything due to circular dependencies. Fix structure. - events: "EventServiceBase" - latents: "LatentsStorageBase" - queue: "InvocationQueueABC" - model_manager: "ModelManager" - restoration: "RestorationServices" - configuration: "InvokeAISettings" - images: "ImageServiceABC" - boards: "BoardServiceABC" board_images: "BoardImagesServiceABC" - graph_library: "ItemStorageABC"["LibraryGraph"] + boards: "BoardServiceABC" + configuration: "InvokeAISettings" + events: "EventServiceBase" graph_execution_manager: "ItemStorageABC"["GraphExecutionState"] + graph_library: "ItemStorageABC"["LibraryGraph"] + images: "ImageServiceABC" + latents: "LatentsStorageBase" + logger: "Logger" + model_manager: "ModelManagerServiceBase" processor: "InvocationProcessorABC" + queue: "InvocationQueueABC" + restoration: "RestorationServices" def __init__( self, - model_manager: "ModelManager", - events: "EventServiceBase", - logger: "Logger", - latents: "LatentsStorageBase", - images: "ImageServiceABC", - boards: "BoardServiceABC", board_images: "BoardImagesServiceABC", - queue: "InvocationQueueABC", - graph_library: "ItemStorageABC"["LibraryGraph"], - graph_execution_manager: "ItemStorageABC"["GraphExecutionState"], - processor: "InvocationProcessorABC", - restoration: "RestorationServices", + boards: "BoardServiceABC", configuration: "InvokeAISettings", + events: "EventServiceBase", + graph_execution_manager: "ItemStorageABC"["GraphExecutionState"], + graph_library: "ItemStorageABC"["LibraryGraph"], + images: "ImageServiceABC", + latents: "LatentsStorageBase", + logger: "Logger", + model_manager: "ModelManagerServiceBase", + processor: "InvocationProcessorABC", + queue: "InvocationQueueABC", + restoration: "RestorationServices", ): - self.model_manager = model_manager - self.events = events - self.logger = logger - self.latents = latents - self.images = images - self.boards = boards self.board_images = board_images - self.queue = queue - self.graph_library = graph_library - self.graph_execution_manager = graph_execution_manager - self.processor = processor - self.restoration = restoration - self.configuration = configuration self.boards = boards + self.boards = boards + self.configuration = configuration + self.events = events + self.graph_execution_manager = graph_execution_manager + self.graph_library = graph_library + self.images = images + self.latents = latents + self.logger = logger + self.model_manager = model_manager + self.processor = processor + self.queue = queue + self.restoration = restoration diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 8b46b17ad0..455d9d021f 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -33,13 +33,13 @@ class ModelManagerServiceBase(ABC): logger: types.ModuleType, ): """ - Initialize with the path to the models.yaml config file. + Initialize with the path to the models.yaml config file. Optional parameters are the torch device type, precision, max_models, and sequential_offload boolean. Note that the default device type and precision are set up for a CUDA system running at half precision. """ pass - + @abstractmethod def get_model( self, @@ -50,8 +50,8 @@ class ModelManagerServiceBase(ABC): node: Optional[BaseInvocation] = None, context: Optional[InvocationContext] = None, ) -> ModelInfo: - """Retrieve the indicated model with name and type. - submodel can be used to get a part (such as the vae) + """Retrieve the indicated model with name and type. + submodel can be used to get a part (such as the vae) of a diffusers pipeline.""" pass @@ -115,8 +115,8 @@ class ModelManagerServiceBase(ABC): """ Update the named model with a dictionary of attributes. Will fail with an assertion error if the name already exists. Pass clobber=True to overwrite. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or + On a successful update, the config will be changed in memory. Will fail + with an assertion error if provided attributes are incorrect or the model name is missing. Call commit() to write changes to disk. """ pass @@ -129,12 +129,35 @@ class ModelManagerServiceBase(ABC): model_type: ModelType, ): """ - Delete the named model from configuration. If delete_files is true, - then the underlying weight file or diffusers directory will be deleted + Delete the named model from configuration. If delete_files is true, + then the underlying weight file or diffusers directory will be deleted as well. Call commit() to write to disk. """ pass + @abstractmethod + def heuristic_import(self, + items_to_import: Set[str], + prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, + )->Dict[str, AddModelResult]: + '''Import a list of paths, repo_ids or URLs. Returns the set of + successfully imported items. + :param items_to_import: Set of strings corresponding to models to be imported. + :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType. + + The prediction type helper is necessary to distinguish between + models based on Stable Diffusion 2 Base (requiring + SchedulerPredictionType.Epsilson) and Stable Diffusion 768 + (requiring SchedulerPredictionType.VPrediction). It is + generally impossible to do this programmatically, so the + prediction_type_helper usually asks the user to choose. + + The result is a set of successfully installed models. Each element + of the set is a dict corresponding to the newly-created OmegaConf stanza for + that model. + ''' + pass + @abstractmethod def commit(self, conf_file: Path = None) -> None: """ @@ -153,7 +176,7 @@ class ModelManagerService(ModelManagerServiceBase): logger: types.ModuleType, ): """ - Initialize with the path to the models.yaml config file. + Initialize with the path to the models.yaml config file. Optional parameters are the torch device type, precision, max_models, and sequential_offload boolean. Note that the default device type and precision are set up for a CUDA system running at half precision. @@ -183,6 +206,8 @@ class ModelManagerService(ModelManagerServiceBase): if hasattr(config,'max_cache_size') \ else config.max_loaded_models * 2.5 + logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB") + sequential_offload = config.sequential_guidance self.mgr = ModelManager( @@ -238,7 +263,7 @@ class ModelManagerService(ModelManagerServiceBase): submodel=submodel, model_info=model_info ) - + return model_info def model_exists( @@ -291,8 +316,8 @@ class ModelManagerService(ModelManagerServiceBase): """ Update the named model with a dictionary of attributes. Will fail with an assertion error if the name already exists. Pass clobber=True to overwrite. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or + On a successful update, the config will be changed in memory. Will fail + with an assertion error if provided attributes are incorrect or the model name is missing. Call commit() to write changes to disk. """ return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber) @@ -305,8 +330,8 @@ class ModelManagerService(ModelManagerServiceBase): model_type: ModelType, ): """ - Delete the named model from configuration. If delete_files is true, - then the underlying weight file or diffusers directory will be deleted + Delete the named model from configuration. If delete_files is true, + then the underlying weight file or diffusers directory will be deleted as well. Call commit() to write to disk. """ self.mgr.del_model(model_name, base_model, model_type) @@ -360,4 +385,25 @@ class ModelManagerService(ModelManagerServiceBase): @property def logger(self): return self.mgr.logger - + + def heuristic_import(self, + items_to_import: Set[str], + prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, + )->Dict[str, AddModelResult]: + '''Import a list of paths, repo_ids or URLs. Returns the set of + successfully imported items. + :param items_to_import: Set of strings corresponding to models to be imported. + :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType. + + The prediction type helper is necessary to distinguish between + models based on Stable Diffusion 2 Base (requiring + SchedulerPredictionType.Epsilson) and Stable Diffusion 768 + (requiring SchedulerPredictionType.VPrediction). It is + generally impossible to do this programmatically, so the + prediction_type_helper usually asks the user to choose. + + The result is a set of successfully installed models. Each element + of the set is a dict corresponding to the newly-created OmegaConf stanza for + that model. + ''' + return self.mgr.heuristic_import(items_to_import, prediction_type_helper) diff --git a/invokeai/backend/install/invokeai_configure.py b/invokeai/backend/install/invokeai_configure.py index a0104bef25..0952a15cf7 100755 --- a/invokeai/backend/install/invokeai_configure.py +++ b/invokeai/backend/install/invokeai_configure.py @@ -430,13 +430,13 @@ to allow InvokeAI to download restricted styles & subjects from the "Concept Lib max_height=len(PRECISION_CHOICES) + 1, scroll_exit=True, ) - self.max_loaded_models = self.add_widget_intelligent( + self.max_cache_size = self.add_widget_intelligent( IntTitleSlider, - name="Number of models to cache in CPU memory (each will use 2-4 GB!)", - value=old_opts.max_loaded_models, - out_of=10, - lowest=1, - begin_entry_at=4, + name="Size of the RAM cache used for fast model switching (GB)", + value=old_opts.max_cache_size, + out_of=20, + lowest=3, + begin_entry_at=6, scroll_exit=True, ) self.nextrely += 1 @@ -539,7 +539,7 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license "outdir", "nsfw_checker", "free_gpu_mem", - "max_loaded_models", + "max_cache_size", "xformers_enabled", "always_use_cpu", ]: @@ -555,9 +555,6 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license new_opts.license_acceptance = self.license_acceptance.value new_opts.precision = PRECISION_CHOICES[self.precision.value[0]] - # widget library workaround to make max_loaded_models an int rather than a float - new_opts.max_loaded_models = int(new_opts.max_loaded_models) - return new_opts diff --git a/invokeai/backend/install/legacy_arg_parsing.py b/invokeai/backend/install/legacy_arg_parsing.py index 4a58ff8336..684c50c77d 100644 --- a/invokeai/backend/install/legacy_arg_parsing.py +++ b/invokeai/backend/install/legacy_arg_parsing.py @@ -4,6 +4,8 @@ import argparse import shlex from argparse import ArgumentParser +# note that this includes both old sampler names and new scheduler names +# in order to be able to parse both 2.0 and 3.0-pre-nodes versions of invokeai.init SAMPLER_CHOICES = [ "ddim", "ddpm", @@ -27,6 +29,15 @@ SAMPLER_CHOICES = [ "dpmpp_sde", "dpmpp_sde_k", "unipc", + "k_dpm_2_a", + "k_dpm_2", + "k_dpmpp_2_a", + "k_dpmpp_2", + "k_euler_a", + "k_euler", + "k_heun", + "k_lms", + "plms", ] PRECISION_CHOICES = [ diff --git a/invokeai/backend/install/migrate_to_3.py b/invokeai/backend/install/migrate_to_3.py index c8e024f484..6f9cee6246 100644 --- a/invokeai/backend/install/migrate_to_3.py +++ b/invokeai/backend/install/migrate_to_3.py @@ -76,6 +76,10 @@ class MigrateTo3(object): Create a unique name for a model for use within models.yaml. ''' done = False + + # some model names have slashes in them, which really screws things up + name = name.replace('/','_') + key = ModelManager.create_key(name,info.base_type,info.model_type) unique_name = key counter = 1 @@ -219,11 +223,12 @@ class MigrateTo3(object): repo_id = 'openai/clip-vit-large-patch14' self._migrate_pretrained(CLIPTokenizer, repo_id= repo_id, - dest= target_dir / 'clip-vit-large-patch14' / 'tokenizer', + dest= target_dir / 'clip-vit-large-patch14', **kwargs) self._migrate_pretrained(CLIPTextModel, repo_id = repo_id, - dest = target_dir / 'clip-vit-large-patch14' / 'text_encoder', + dest = target_dir / 'clip-vit-large-patch14', + force = True, **kwargs) # sd-2 @@ -287,21 +292,21 @@ class MigrateTo3(object): def _model_probe_to_path(self, info: ModelProbeInfo)->Path: return Path(self.dest_models, info.base_type.value, info.model_type.value) - def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, **kwargs): - if dest.exists(): + def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force:bool=False, **kwargs): + if dest.exists() and not force: logger.info(f'Skipping existing {dest}') return model = model_class.from_pretrained(repo_id, **kwargs) - self._save_pretrained(model, dest) + self._save_pretrained(model, dest, overwrite=force) - def _save_pretrained(self, model, dest: Path): - if dest.exists(): - logger.info(f'Skipping existing {dest}') - return + def _save_pretrained(self, model, dest: Path, overwrite: bool=False): model_name = dest.name - download_path = dest.with_name(f'{model_name}.downloading') - model.save_pretrained(download_path, safe_serialization=True) - download_path.replace(dest) + if overwrite: + model.save_pretrained(dest, safe_serialization=True) + else: + download_path = dest.with_name(f'{model_name}.downloading') + model.save_pretrained(download_path, safe_serialization=True) + download_path.replace(dest) def _download_vae(self, repo_id: str, subfolder:str=None)->Path: vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / 'models/hub', subfolder=subfolder) @@ -569,8 +574,10 @@ script, which will perform a full upgrade in place.""" dest_directory = args.dest_directory assert dest_directory.is_dir(), f"{dest_directory} is not a valid directory" - assert (dest_directory / 'models').is_dir(), f"{dest_directory} does not contain a 'models' subdirectory" - assert (dest_directory / 'invokeai.yaml').exists(), f"{dest_directory} does not contain an InvokeAI init file." + + # TODO: revisit + # assert (dest_directory / 'models').is_dir(), f"{dest_directory} does not contain a 'models' subdirectory" + # assert (dest_directory / 'invokeai.yaml').exists(), f"{dest_directory} does not contain an InvokeAI init file." do_migrate(root_directory,dest_directory) diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py index 1c2f4d2fc1..86a922c05a 100644 --- a/invokeai/backend/install/model_install_backend.py +++ b/invokeai/backend/install/model_install_backend.py @@ -18,7 +18,7 @@ from tqdm import tqdm import invokeai.configs as configs from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType +from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo from invokeai.backend.util import download_with_resume from ..util.logging import InvokeAILogger @@ -166,17 +166,22 @@ class ModelInstall(object): # add requested models for path in selections.install_models: logger.info(f'Installing {path} [{job}/{jobs}]') - self.heuristic_install(path) + self.heuristic_import(path) job += 1 self.mgr.commit() - def heuristic_install(self, + def heuristic_import(self, model_path_id_or_url: Union[str,Path], - models_installed: Set[Path]=None)->Set[Path]: + models_installed: Set[Path]=None)->Dict[str, AddModelResult]: + ''' + :param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL + :param models_installed: Set of installed models, used for recursive invocation + Returns a set of dict objects corresponding to newly-created stanzas in models.yaml. + ''' if not models_installed: - models_installed = set() + models_installed = dict() # A little hack to allow nested routines to retrieve info on the requested ID self.current_id = model_path_id_or_url @@ -185,24 +190,27 @@ class ModelInstall(object): try: # checkpoint file, or similar if path.is_file(): - models_installed.add(self._install_path(path)) + models_installed.update(self._install_path(path)) # folders style or similar - elif path.is_dir() and any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]): - models_installed.add(self._install_path(path)) + elif path.is_dir() and any([(path/x).exists() for x in \ + {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'} + ] + ): + models_installed.update(self._install_path(path)) # recursive scan elif path.is_dir(): for child in path.iterdir(): - self.heuristic_install(child, models_installed=models_installed) + self.heuristic_import(child, models_installed=models_installed) # huggingface repo elif len(str(path).split('/')) == 2: - models_installed.add(self._install_repo(str(path))) + models_installed.update(self._install_repo(str(path))) # a URL elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")): - models_installed.add(self._install_url(model_path_id_or_url)) + models_installed.update(self._install_url(model_path_id_or_url)) else: logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping') @@ -214,24 +222,25 @@ class ModelInstall(object): # install a model from a local path. The optional info parameter is there to prevent # the model from being probed twice in the event that it has already been probed. - def _install_path(self, path: Path, info: ModelProbeInfo=None)->Path: + def _install_path(self, path: Path, info: ModelProbeInfo=None)->Dict[str, AddModelResult]: try: - # logger.debug(f'Probing {path}') + model_result = None info = info or ModelProbe().heuristic_probe(path,self.prediction_helper) - model_name = path.stem if info.format=='checkpoint' else path.name + model_name = path.stem if path.is_file() else path.name if self.mgr.model_exists(model_name, info.base_type, info.model_type): raise ValueError(f'A model named "{model_name}" is already installed.') attributes = self._make_attributes(path,info) - self.mgr.add_model(model_name = model_name, - base_model = info.base_type, - model_type = info.model_type, - model_attributes = attributes, - ) + model_result = self.mgr.add_model(model_name = model_name, + base_model = info.base_type, + model_type = info.model_type, + model_attributes = attributes, + ) except Exception as e: logger.warning(f'{str(e)} Skipping registration.') - return path + return {} + return {str(path): model_result} - def _install_url(self, url: str)->Path: + def _install_url(self, url: str)->dict: # copy to a staging area, probe, import and delete with TemporaryDirectory(dir=self.config.models_path) as staging: location = download_with_resume(url,Path(staging)) @@ -244,7 +253,7 @@ class ModelInstall(object): # staged version will be garbage-collected at this time return self._install_path(Path(models_path), info) - def _install_repo(self, repo_id: str)->Path: + def _install_repo(self, repo_id: str)->dict: hinfo = HfApi().model_info(repo_id) # we try to figure out how to download this most economically diff --git a/invokeai/backend/model_management/__init__.py b/invokeai/backend/model_management/__init__.py index fb3b20a20a..34e0b15728 100644 --- a/invokeai/backend/model_management/__init__.py +++ b/invokeai/backend/model_management/__init__.py @@ -1,7 +1,7 @@ """ Initialization file for invokeai.backend.model_management """ -from .model_manager import ModelManager, ModelInfo +from .model_manager import ModelManager, ModelInfo, AddModelResult from .model_cache import ModelCache from .models import BaseModelType, ModelType, SubModelType, ModelVariantType diff --git a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py index 1eeee92fb7..e3e64940de 100644 --- a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py +++ b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py @@ -29,7 +29,7 @@ import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig from .model_manager import ModelManager -from .model_cache import ModelCache +from picklescan.scanner import scan_file_path from .models import BaseModelType, ModelVariantType try: @@ -1014,7 +1014,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt( checkpoint = load_file(checkpoint_path) else: if scan_needed: - ModelCache.scan_model(checkpoint_path, checkpoint_path) + # scan model + scan_result = scan_file_path(checkpoint_path) + if scan_result.infected_files != 0: + raise "The model {checkpoint_path} is potentially infected by malware. Aborting import." checkpoint = torch.load(checkpoint_path) # sometimes there is a state_dict key and sometimes not diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index 6cfcb8dd8d..d8ecdf81c2 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -1,18 +1,15 @@ from __future__ import annotations import copy -from pathlib import Path from contextlib import contextmanager -from typing import Optional, Dict, Tuple, Any +from pathlib import Path +from typing import Any, Dict, Optional, Tuple, Union, List import torch -from safetensors.torch import load_file -from torch.utils.hooks import RemovableHandle - -from diffusers.models import UNet2DConditionModel -from transformers import CLIPTextModel - from compel.embeddings_provider import BaseTextualInversionManager +from diffusers.models import UNet2DConditionModel +from safetensors.torch import load_file +from transformers import CLIPTextModel, CLIPTokenizer class LoRALayerBase: #rank: Optional[int] @@ -124,8 +121,8 @@ class LoRALayer(LoRALayerBase): def get_weight(self): if self.mid is not None: - up = self.up.reshape(up.shape[0], up.shape[1]) - down = self.down.reshape(up.shape[0], up.shape[1]) + up = self.up.reshape(self.up.shape[0], self.up.shape[1]) + down = self.down.reshape(self.down.shape[0], self.down.shape[1]) weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down) else: weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1) @@ -411,7 +408,7 @@ class LoRAModel: #(torch.nn.Module): else: # TODO: diff/ia3/... format print( - f">> Encountered unknown lora layer module in {self.name}: {layer_key}" + f">> Encountered unknown lora layer module in {model.name}: {layer_key}" ) return @@ -539,9 +536,10 @@ class ModelPatcher: original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True) # enable autocast to calc fp16 loras on cpu - with torch.autocast(device_type="cpu"): - layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 - layer_weight = layer.get_weight() * lora_weight * layer_scale + #with torch.autocast(device_type="cpu"): + layer.to(dtype=torch.float32) + layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 + layer_weight = layer.get_weight() * lora_weight * layer_scale if module.weight.shape != layer_weight.shape: # TODO: debug on lycoris @@ -655,6 +653,9 @@ class TextualInversionModel: else: result.embedding = next(iter(state_dict.values())) + if len(result.embedding.shape) == 1: + result.embedding = result.embedding.unsqueeze(0) + if not isinstance(result.embedding, torch.Tensor): raise ValueError(f"Invalid embeddings file: {file_path.name}") diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py index 77b6ac5115..4155edb686 100644 --- a/invokeai/backend/model_management/model_cache.py +++ b/invokeai/backend/model_management/model_cache.py @@ -8,7 +8,7 @@ The cache returns context manager generators designed to load the model into the GPU within the context, and unload outside the context. Use like this: - cache = ModelCache(max_models_cached=6) + cache = ModelCache(max_cache_size=7.5) with cache.get_model('runwayml/stable-diffusion-1-5') as SD1, cache.get_model('stabilityai/stable-diffusion-2') as SD2: do_something_in_GPU(SD1,SD2) @@ -91,7 +91,7 @@ class ModelCache(object): logger: types.ModuleType = logger ): ''' - :param max_models: Maximum number of models to cache in CPU RAM [4] + :param max_cache_size: Maximum size of the RAM cache [6.0 GB] :param execution_device: Torch device to load active model into [torch.device('cuda')] :param storage_device: Torch device to save inactive model in [torch.device('cpu')] :param precision: Precision for loaded models [torch.float16] @@ -100,8 +100,6 @@ class ModelCache(object): :param sha_chunksize: Chunksize to use when calculating sha256 model hash ''' #max_cache_size = 9999 - execution_device = torch.device('cuda') - self.model_infos: Dict[str, ModelBase] = dict() self.lazy_offloading = lazy_offloading #self.sequential_offload: bool=sequential_offload @@ -128,16 +126,6 @@ class ModelCache(object): key += f":{submodel_type}" return key - #def get_model( - # self, - # repo_id_or_path: Union[str, Path], - # model_type: ModelType = ModelType.Diffusers, - # subfolder: Path = None, - # submodel: ModelType = None, - # revision: str = None, - # attach_model_part: Tuple[ModelType, str] = (None, None), - # gpu_load: bool = True, - #) -> ModelLocker: # ?? what does it return def _get_model_info( self, model_path: str, diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 7dc174bbce..db8a691d29 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -233,14 +233,14 @@ import hashlib import textwrap from dataclasses import dataclass from pathlib import Path -from typing import Optional, List, Tuple, Union, Set, Callable, types +from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types from shutil import rmtree import torch from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig -from pydantic import BaseModel +from pydantic import BaseModel, Field import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig @@ -249,7 +249,7 @@ from .model_cache import ModelCache, ModelLocker from .models import ( BaseModelType, ModelType, SubModelType, ModelError, SchedulerPredictionType, MODEL_CLASSES, - ModelConfigBase, + ModelConfigBase, ModelNotFoundException, ) # We are only starting to number the config file with release 3. @@ -278,8 +278,13 @@ class InvalidModelError(Exception): "Raised when an invalid model is requested" pass -MAX_CACHE_SIZE = 6.0 # GB +class AddModelResult(BaseModel): + name: str = Field(description="The name of the model after import") + model_type: ModelType = Field(description="The type of model") + base_model: BaseModelType = Field(description="The base model") + config: ModelConfigBase = Field(description="The configuration of the model") +MAX_CACHE_SIZE = 6.0 # GB class ConfigMeta(BaseModel): version: str @@ -404,7 +409,7 @@ class ModelManager(object): if model_key not in self.models: self.scan_models_directory(base_model=base_model, model_type=model_type) if model_key not in self.models: - raise Exception(f"Model not found - {model_key}") + raise ModelNotFoundException(f"Model not found - {model_key}") model_config = self.models[model_key] model_path = self.app_config.root_path / model_config.path @@ -416,14 +421,14 @@ class ModelManager(object): else: self.models.pop(model_key, None) - raise Exception(f"Model not found - {model_key}") + raise ModelNotFoundException(f"Model not found - {model_key}") # vae/movq override # TODO: if submodel_type is not None and hasattr(model_config, submodel_type): override_path = getattr(model_config, submodel_type) if override_path: - model_path = override_path + model_path = self.app_config.root_path / override_path model_type = submodel_type submodel_type = None model_class = MODEL_CLASSES[base_model][model_type] @@ -431,6 +436,7 @@ class ModelManager(object): # TODO: path # TODO: is it accurate to use path as id dst_convert_path = self._get_model_cache_path(model_path) + model_path = model_class.convert_if_required( base_model=base_model, model_path=str(model_path), # TODO: refactor str/Path types logic @@ -570,13 +576,16 @@ class ModelManager(object): model_type: ModelType, model_attributes: dict, clobber: bool = False, - ) -> None: + ) -> AddModelResult: """ Update the named model with a dictionary of attributes. Will fail with an assertion error if the name already exists. Pass clobber=True to overwrite. On a successful update, the config will be changed in memory and the method will return True. Will fail with an assertion error if provided attributes are incorrect or the model name is missing. + + The returned dict has the same format as the dict returned by + model_info(). """ model_class = MODEL_CLASSES[base_model][model_type] @@ -600,12 +609,18 @@ class ModelManager(object): old_model_cache.unlink() # remove in-memory cache - # note: it not garantie to release memory(model can has other references) + # note: it not guaranteed to release memory(model can has other references) cache_ids = self.cache_keys.pop(model_key, []) for cache_id in cache_ids: self.cache.uncache_model(cache_id) self.models[model_key] = model_config + return AddModelResult( + name = model_name, + model_type = model_type, + base_model = base_model, + config = model_config, + ) def search_models(self, search_folder): self.logger.info(f"Finding Models In: {search_folder}") @@ -716,19 +731,19 @@ class ModelManager(object): if model_path.is_relative_to(self.app_config.root_path): model_path = model_path.relative_to(self.app_config.root_path) - try: - model_config: ModelConfigBase = model_class.probe_config(str(model_path)) - self.models[model_key] = model_config - new_models_found = True - except NotImplementedError as e: - self.logger.warning(e) + try: + model_config: ModelConfigBase = model_class.probe_config(str(model_path)) + self.models[model_key] = model_config + new_models_found = True + except NotImplementedError as e: + self.logger.warning(e) imported_models = self.autoimport() if (new_models_found or imported_models) and self.config_path: self.commit() - def autoimport(self)->set[Path]: + def autoimport(self)->Dict[str, AddModelResult]: ''' Scan the autoimport directory (if defined) and import new models, delete defunct models. ''' @@ -741,7 +756,6 @@ class ModelManager(object): prediction_type_helper = ask_user_for_prediction_type, ) - installed = set() scanned_dirs = set() config = self.app_config @@ -755,13 +769,14 @@ class ModelManager(object): continue self.logger.info(f'Scanning {autodir} for models to import') + installed = dict() autodir = self.app_config.root_path / autodir if not autodir.exists(): continue items_scanned = 0 - new_models_found = set() + new_models_found = dict() for root, dirs, files in os.walk(autodir): items_scanned += len(dirs) + len(files) @@ -770,8 +785,8 @@ class ModelManager(object): if path in known_paths or path.parent in scanned_dirs: scanned_dirs.add(path) continue - if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]): - new_models_found.update(installer.heuristic_install(path)) + if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]): + new_models_found.update(installer.heuristic_import(path)) scanned_dirs.add(path) for f in files: @@ -779,7 +794,8 @@ class ModelManager(object): if path in known_paths or path.parent in scanned_dirs: continue if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}: - new_models_found.update(installer.heuristic_install(path)) + import_result = installer.heuristic_import(path) + new_models_found.update(import_result) self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models') installed.update(new_models_found) @@ -789,7 +805,7 @@ class ModelManager(object): def heuristic_import(self, items_to_import: Set[str], prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, - )->Set[str]: + )->Dict[str, AddModelResult]: '''Import a list of paths, repo_ids or URLs. Returns the set of successfully imported items. :param items_to_import: Set of strings corresponding to models to be imported. @@ -802,17 +818,20 @@ class ModelManager(object): generally impossible to do this programmatically, so the prediction_type_helper usually asks the user to choose. + The result is a set of successfully installed models. Each element + of the set is a dict corresponding to the newly-created OmegaConf stanza for + that model. ''' # avoid circular import here from invokeai.backend.install.model_install_backend import ModelInstall - successfully_installed = set() + successfully_installed = dict() installer = ModelInstall(config = self.app_config, prediction_type_helper = prediction_type_helper, model_manager = self) for thing in items_to_import: try: - installed = installer.heuristic_install(thing) + installed = installer.heuristic_import(thing) successfully_installed.update(installed) except Exception as e: self.logger.warning(f'{thing} could not be imported: {str(e)}') diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py index 2828cc7ab1..eef3292d6d 100644 --- a/invokeai/backend/model_management/model_probe.py +++ b/invokeai/backend/model_management/model_probe.py @@ -78,7 +78,6 @@ class ModelProbe(object): format_type = 'diffusers' if model_path.is_dir() else 'checkpoint' else: format_type = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint' - model_info = None try: model_type = cls.get_model_type_from_folder(model_path, model) \ @@ -105,7 +104,7 @@ class ModelProbe(object): ) else 512, ) except Exception: - return None + raise return model_info @@ -127,6 +126,8 @@ class ModelProbe(object): return ModelType.Vae elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}): return ModelType.Lora + elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}): + return ModelType.Lora elif any(key.startswith(v) for v in {"control_model", "input_blocks"}): return ModelType.ControlNet elif key in {"emb_params", "string_to_param"}: @@ -137,7 +138,7 @@ class ModelProbe(object): if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()): return ModelType.TextualInversion - raise ValueError("Unable to determine model type") + raise ValueError(f"Unable to determine model type for {model_path}") @classmethod def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType: @@ -167,7 +168,7 @@ class ModelProbe(object): return type # give up - raise ValueError("Unable to determine model type") + raise ValueError("Unable to determine model type for {folder_path}") @classmethod def _scan_and_load_checkpoint(cls,model_path: Path)->dict: diff --git a/invokeai/backend/model_management/models/__init__.py b/invokeai/backend/model_management/models/__init__.py index 87b0ad3c4e..00630eef62 100644 --- a/invokeai/backend/model_management/models/__init__.py +++ b/invokeai/backend/model_management/models/__init__.py @@ -2,7 +2,7 @@ import inspect from enum import Enum from pydantic import BaseModel from typing import Literal, get_origin -from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings +from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model from .vae import VaeModel from .lora import LoRAModel diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py index afa62b2e4f..57c02bce76 100644 --- a/invokeai/backend/model_management/models/base.py +++ b/invokeai/backend/model_management/models/base.py @@ -15,6 +15,9 @@ from contextlib import suppress from pydantic import BaseModel, Field from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union +class ModelNotFoundException(Exception): + pass + class BaseModelType(str, Enum): StableDiffusion1 = "sd-1" StableDiffusion2 = "sd-2" diff --git a/invokeai/backend/model_management/models/textual_inversion.py b/invokeai/backend/model_management/models/textual_inversion.py index 9a032218f0..4dcdbb24ba 100644 --- a/invokeai/backend/model_management/models/textual_inversion.py +++ b/invokeai/backend/model_management/models/textual_inversion.py @@ -8,6 +8,7 @@ from .base import ( ModelType, SubModelType, classproperty, + ModelNotFoundException, ) # TODO: naming from ..lora import TextualInversionModel as TextualInversionModelRaw @@ -37,8 +38,15 @@ class TextualInversionModel(ModelBase): if child_type is not None: raise Exception("There is no child models in textual inversion") + checkpoint_path = self.model_path + if os.path.isdir(checkpoint_path): + checkpoint_path = os.path.join(checkpoint_path, "learned_embeds.bin") + + if not os.path.exists(checkpoint_path): + raise ModelNotFoundException() + model = TextualInversionModelRaw.from_checkpoint( - file_path=self.model_path, + file_path=checkpoint_path, dtype=torch_dtype, ) diff --git a/invokeai/frontend/install/model_install.py b/invokeai/frontend/install/model_install.py index 33ef114912..f3ebcb22be 100644 --- a/invokeai/frontend/install/model_install.py +++ b/invokeai/frontend/install/model_install.py @@ -678,9 +678,8 @@ def select_and_download_models(opt: Namespace): # this is where the TUI is called else: - # needed because the torch library is loaded, even though we don't use it - # currently commented out because it has started generating errors (?) - # torch.multiprocessing.set_start_method("spawn") + # needed to support the probe() method running under a subprocess + torch.multiprocessing.set_start_method("spawn") # the third argument is needed in the Windows 11 environment in # order to launch and resize a console window running this program diff --git a/invokeai/frontend/web/.eslintrc.js b/invokeai/frontend/web/.eslintrc.js index b1a2b6a7e4..34db9d466b 100644 --- a/invokeai/frontend/web/.eslintrc.js +++ b/invokeai/frontend/web/.eslintrc.js @@ -36,6 +36,12 @@ module.exports = { ], 'prettier/prettier': ['error', { endOfLine: 'auto' }], '@typescript-eslint/ban-ts-comment': 'warn', + '@typescript-eslint/no-empty-interface': [ + 'error', + { + allowSingleExtends: true, + }, + ], }, settings: { react: { diff --git a/invokeai/frontend/web/dist/index.html b/invokeai/frontend/web/dist/index.html index 6c4c1c21ae..a0adc1d803 100644 --- a/invokeai/frontend/web/dist/index.html +++ b/invokeai/frontend/web/dist/index.html @@ -12,7 +12,7 @@ margin: 0; } - + diff --git a/invokeai/frontend/web/dist/locales/en.json b/invokeai/frontend/web/dist/locales/en.json index 7a73bae411..6fb56a2979 100644 --- a/invokeai/frontend/web/dist/locales/en.json +++ b/invokeai/frontend/web/dist/locales/en.json @@ -24,16 +24,13 @@ }, "common": { "hotkeysLabel": "Hotkeys", - "themeLabel": "Theme", + "darkMode": "Dark Mode", + "lightMode": "Light Mode", "languagePickerLabel": "Language", "reportBugLabel": "Report Bug", "githubLabel": "Github", "discordLabel": "Discord", "settingsLabel": "Settings", - "darkTheme": "Dark", - "lightTheme": "Light", - "greenTheme": "Green", - "oceanTheme": "Ocean", "langArabic": "العربية", "langEnglish": "English", "langDutch": "Nederlands", @@ -55,6 +52,7 @@ "unifiedCanvas": "Unified Canvas", "linear": "Linear", "nodes": "Node Editor", + "modelmanager": "Model Manager", "postprocessing": "Post Processing", "nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.", "postProcessing": "Post Processing", @@ -336,6 +334,7 @@ "modelManager": { "modelManager": "Model Manager", "model": "Model", + "vae": "VAE", "allModels": "All Models", "checkpointModels": "Checkpoints", "diffusersModels": "Diffusers", @@ -351,6 +350,7 @@ "scanForModels": "Scan For Models", "addManually": "Add Manually", "manual": "Manual", + "baseModel": "Base Model", "name": "Name", "nameValidationMsg": "Enter a name for your model", "description": "Description", @@ -363,6 +363,7 @@ "repoIDValidationMsg": "Online repository of your model", "vaeLocation": "VAE Location", "vaeLocationValidationMsg": "Path to where your VAE is located.", + "variant": "Variant", "vaeRepoID": "VAE Repo ID", "vaeRepoIDValidationMsg": "Online repository of your VAE", "width": "Width", @@ -524,7 +525,8 @@ "initialImage": "Initial Image", "showOptionsPanel": "Show Options Panel", "hidePreview": "Hide Preview", - "showPreview": "Show Preview" + "showPreview": "Show Preview", + "controlNetControlMode": "Control Mode" }, "settings": { "models": "Models", @@ -547,7 +549,8 @@ "general": "General", "generation": "Generation", "ui": "User Interface", - "availableSchedulers": "Available Schedulers" + "favoriteSchedulers": "Favorite Schedulers", + "favoriteSchedulersPlaceholder": "No schedulers favorited" }, "toast": { "serverError": "Server Error", diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json index 786a721d5c..cd86bfdbe8 100644 --- a/invokeai/frontend/web/package.json +++ b/invokeai/frontend/web/package.json @@ -67,6 +67,7 @@ "@fontsource-variable/inter": "^5.0.3", "@fontsource/inter": "^5.0.3", "@mantine/core": "^6.0.14", + "@mantine/form": "^6.0.15", "@mantine/hooks": "^6.0.14", "@reduxjs/toolkit": "^1.9.5", "@roarr/browser-log-writer": "^1.1.5", @@ -82,7 +83,7 @@ "konva": "^9.2.0", "lodash-es": "^4.17.21", "nanostores": "^0.9.2", - "openapi-fetch": "^0.4.0", + "openapi-fetch": "0.4.0", "overlayscrollbars": "^2.2.0", "overlayscrollbars-react": "^0.5.0", "patch-package": "^7.0.0", diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 1b3b790222..9cf1e0bc48 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -52,6 +52,8 @@ "unifiedCanvas": "Unified Canvas", "linear": "Linear", "nodes": "Node Editor", + "batch": "Batch Manager", + "modelmanager": "Model Manager", "postprocessing": "Post Processing", "nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.", "postProcessing": "Post Processing", @@ -333,6 +335,7 @@ "modelManager": { "modelManager": "Model Manager", "model": "Model", + "vae": "VAE", "allModels": "All Models", "checkpointModels": "Checkpoints", "diffusersModels": "Diffusers", @@ -348,6 +351,7 @@ "scanForModels": "Scan For Models", "addManually": "Add Manually", "manual": "Manual", + "baseModel": "Base Model", "name": "Name", "nameValidationMsg": "Enter a name for your model", "description": "Description", @@ -360,6 +364,7 @@ "repoIDValidationMsg": "Online repository of your model", "vaeLocation": "VAE Location", "vaeLocationValidationMsg": "Path to where your VAE is located.", + "variant": "Variant", "vaeRepoID": "VAE Repo ID", "vaeRepoIDValidationMsg": "Online repository of your VAE", "width": "Width", diff --git a/invokeai/frontend/web/src/app/components/App.tsx b/invokeai/frontend/web/src/app/components/App.tsx index 5b3cf5925f..f43c8fc5c0 100644 --- a/invokeai/frontend/web/src/app/components/App.tsx +++ b/invokeai/frontend/web/src/app/components/App.tsx @@ -1,67 +1,40 @@ -import { Box, Flex, Grid, Portal } from '@chakra-ui/react'; +import { Flex, Grid, Portal } from '@chakra-ui/react'; import { useLogger } from 'app/logging/useLogger'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { PartialAppConfig } from 'app/types/invokeai'; import ImageUploader from 'common/components/ImageUploader'; -import Loading from 'common/components/Loading/Loading'; import GalleryDrawer from 'features/gallery/components/GalleryPanel'; +import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal'; import Lightbox from 'features/lightbox/components/Lightbox'; import SiteHeader from 'features/system/components/SiteHeader'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; -import { useIsApplicationReady } from 'features/system/hooks/useIsApplicationReady'; import { configChanged } from 'features/system/store/configSlice'; import { languageSelector } from 'features/system/store/systemSelectors'; import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton'; import FloatingParametersPanelButtons from 'features/ui/components/FloatingParametersPanelButtons'; import InvokeTabs from 'features/ui/components/InvokeTabs'; import ParametersDrawer from 'features/ui/components/ParametersDrawer'; -import { AnimatePresence, motion } from 'framer-motion'; import i18n from 'i18n'; -import { ReactNode, memo, useCallback, useEffect, useState } from 'react'; -import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants'; +import { ReactNode, memo, useEffect } from 'react'; +import DeleteBoardImagesModal from '../../features/gallery/components/Boards/DeleteBoardImagesModal'; +import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal'; import GlobalHotkeys from './GlobalHotkeys'; import Toaster from './Toaster'; -import DeleteImageModal from 'features/gallery/components/DeleteImageModal'; -import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; -import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal'; -import { useListModelsQuery } from 'services/api/endpoints/models'; -import DeleteBoardImagesModal from '../../features/gallery/components/Boards/DeleteBoardImagesModal'; const DEFAULT_CONFIG = {}; interface Props { config?: PartialAppConfig; headerComponent?: ReactNode; - setIsReady?: (isReady: boolean) => void; } -const App = ({ - config = DEFAULT_CONFIG, - headerComponent, - setIsReady, -}: Props) => { +const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => { const language = useAppSelector(languageSelector); const log = useLogger(); const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled; - const isApplicationReady = useIsApplicationReady(); - - const { data: pipelineModels } = useListModelsQuery({ - model_type: 'main', - }); - const { data: controlnetModels } = useListModelsQuery({ - model_type: 'controlnet', - }); - const { data: vaeModels } = useListModelsQuery({ model_type: 'vae' }); - const { data: loraModels } = useListModelsQuery({ model_type: 'lora' }); - const { data: embeddingModels } = useListModelsQuery({ - model_type: 'embedding', - }); - - const [loadingOverridden, setLoadingOverridden] = useState(false); - const dispatch = useAppDispatch(); useEffect(() => { @@ -73,27 +46,6 @@ const App = ({ dispatch(configChanged(config)); }, [dispatch, config, log]); - const handleOverrideClicked = useCallback(() => { - setLoadingOverridden(true); - }, []); - - useEffect(() => { - if (isApplicationReady && setIsReady) { - setIsReady(true); - } - - if (isApplicationReady) { - // TODO: This is a jank fix for canvas not filling the screen on first load - setTimeout(() => { - dispatch(requestCanvasRescale()); - }, 200); - } - - return () => { - setIsReady && setIsReady(false); - }; - }, [dispatch, isApplicationReady, setIsReady]); - return ( <> @@ -123,33 +75,6 @@ const App = ({ - - - {!isApplicationReady && !loadingOverridden && ( - - - - - - - )} - - diff --git a/invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx b/invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx new file mode 100644 index 0000000000..bf66c0ee08 --- /dev/null +++ b/invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx @@ -0,0 +1,122 @@ +import { Box, ChakraProps, Flex, Heading, Image } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import { memo } from 'react'; +import { TypesafeDraggableData } from './typesafeDnd'; + +type OverlayDragImageProps = { + dragData: TypesafeDraggableData | null; +}; + +const BOX_SIZE = 28; + +const STYLES: ChakraProps['sx'] = { + w: BOX_SIZE, + h: BOX_SIZE, + maxW: BOX_SIZE, + maxH: BOX_SIZE, + shadow: 'dark-lg', + borderRadius: 'lg', + borderWidth: 2, + borderStyle: 'dashed', + borderColor: 'base.100', + opacity: 0.5, + bg: 'base.800', + color: 'base.50', + _dark: { + borderColor: 'base.200', + bg: 'base.900', + color: 'base.100', + }, +}; + +const selector = createSelector( + stateSelector, + (state) => { + const gallerySelectionCount = state.gallery.selection.length; + const batchSelectionCount = state.batch.selection.length; + + return { + gallerySelectionCount, + batchSelectionCount, + }; + }, + defaultSelectorOptions +); + +const DragPreview = (props: OverlayDragImageProps) => { + const { gallerySelectionCount, batchSelectionCount } = + useAppSelector(selector); + + if (!props.dragData) { + return; + } + + if (props.dragData.payloadType === 'IMAGE_DTO') { + return ( + + + + ); + } + + if (props.dragData.payloadType === 'BATCH_SELECTION') { + return ( + + {batchSelectionCount} + Images + + ); + } + + if (props.dragData.payloadType === 'GALLERY_SELECTION') { + return ( + + {gallerySelectionCount} + Images + + ); + } + + return null; +}; + +export default memo(DragPreview); diff --git a/invokeai/frontend/web/src/app/components/ImageDnd/ImageDndContext.tsx b/invokeai/frontend/web/src/app/components/ImageDnd/ImageDndContext.tsx index 6150259f66..1b8687bf8e 100644 --- a/invokeai/frontend/web/src/app/components/ImageDnd/ImageDndContext.tsx +++ b/invokeai/frontend/web/src/app/components/ImageDnd/ImageDndContext.tsx @@ -1,8 +1,5 @@ import { - DndContext, - DragEndEvent, DragOverlay, - DragStartEvent, MouseSensor, TouchSensor, pointerWithin, @@ -10,33 +7,45 @@ import { useSensors, } from '@dnd-kit/core'; import { PropsWithChildren, memo, useCallback, useState } from 'react'; -import OverlayDragImage from './OverlayDragImage'; -import { ImageDTO } from 'services/api/types'; -import { isImageDTO } from 'services/api/guards'; +import DragPreview from './DragPreview'; import { snapCenterToCursor } from '@dnd-kit/modifiers'; import { AnimatePresence, motion } from 'framer-motion'; +import { + DndContext, + DragEndEvent, + DragStartEvent, + TypesafeDraggableData, +} from './typesafeDnd'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { imageDropped } from 'app/store/middleware/listenerMiddleware/listeners/imageDropped'; type ImageDndContextProps = PropsWithChildren; const ImageDndContext = (props: ImageDndContextProps) => { - const [draggedImage, setDraggedImage] = useState(null); + const [activeDragData, setActiveDragData] = + useState(null); + + const dispatch = useAppDispatch(); const handleDragStart = useCallback((event: DragStartEvent) => { - const dragData = event.active.data.current; - if (dragData && 'image' in dragData && isImageDTO(dragData.image)) { - setDraggedImage(dragData.image); + const activeData = event.active.data.current; + if (!activeData) { + return; } + setActiveDragData(activeData); }, []); const handleDragEnd = useCallback( (event: DragEndEvent) => { - const handleDrop = event.over?.data.current?.handleDrop; - if (handleDrop && typeof handleDrop === 'function' && draggedImage) { - handleDrop(draggedImage); + const activeData = event.active.data.current; + const overData = event.over?.data.current; + if (!activeData || !overData) { + return; } - setDraggedImage(null); + dispatch(imageDropped({ overData, activeData })); + setActiveDragData(null); }, - [draggedImage] + [dispatch] ); const mouseSensor = useSensor(MouseSensor, { @@ -46,6 +55,7 @@ const ImageDndContext = (props: ImageDndContextProps) => { const touchSensor = useSensor(TouchSensor, { activationConstraint: { delay: 150, tolerance: 5 }, }); + // TODO: Use KeyboardSensor - needs composition of multiple collisionDetection algos // Alternatively, fix `rectIntersection` collection detection to work with the drag overlay // (currently the drag element collision rect is not correctly calculated) @@ -63,7 +73,7 @@ const ImageDndContext = (props: ImageDndContextProps) => { {props.children} - {draggedImage && ( + {activeDragData && ( { transition: { duration: 0.1 }, }} > - + )} diff --git a/invokeai/frontend/web/src/app/components/ImageDnd/OverlayDragImage.tsx b/invokeai/frontend/web/src/app/components/ImageDnd/OverlayDragImage.tsx deleted file mode 100644 index 611d1ceee9..0000000000 --- a/invokeai/frontend/web/src/app/components/ImageDnd/OverlayDragImage.tsx +++ /dev/null @@ -1,36 +0,0 @@ -import { Box, Image } from '@chakra-ui/react'; -import { memo } from 'react'; -import { ImageDTO } from 'services/api/types'; - -type OverlayDragImageProps = { - image: ImageDTO; -}; - -const OverlayDragImage = (props: OverlayDragImageProps) => { - return ( - - - - ); -}; - -export default memo(OverlayDragImage); diff --git a/invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx b/invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx new file mode 100644 index 0000000000..1478ace748 --- /dev/null +++ b/invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx @@ -0,0 +1,201 @@ +// type-safe dnd from https://github.com/clauderic/dnd-kit/issues/935 +import { + Active, + Collision, + DndContextProps, + DndContext as OriginalDndContext, + Over, + Translate, + UseDraggableArguments, + UseDroppableArguments, + useDraggable as useOriginalDraggable, + useDroppable as useOriginalDroppable, +} from '@dnd-kit/core'; +import { ImageDTO } from 'services/api/types'; + +type BaseDropData = { + id: string; +}; + +export type CurrentImageDropData = BaseDropData & { + actionType: 'SET_CURRENT_IMAGE'; +}; + +export type InitialImageDropData = BaseDropData & { + actionType: 'SET_INITIAL_IMAGE'; +}; + +export type ControlNetDropData = BaseDropData & { + actionType: 'SET_CONTROLNET_IMAGE'; + context: { + controlNetId: string; + }; +}; + +export type CanvasInitialImageDropData = BaseDropData & { + actionType: 'SET_CANVAS_INITIAL_IMAGE'; +}; + +export type NodesImageDropData = BaseDropData & { + actionType: 'SET_NODES_IMAGE'; + context: { + nodeId: string; + fieldName: string; + }; +}; + +export type NodesMultiImageDropData = BaseDropData & { + actionType: 'SET_MULTI_NODES_IMAGE'; + context: { nodeId: string; fieldName: string }; +}; + +export type AddToBatchDropData = BaseDropData & { + actionType: 'ADD_TO_BATCH'; +}; + +export type MoveBoardDropData = BaseDropData & { + actionType: 'MOVE_BOARD'; + context: { boardId: string | null }; +}; + +export type TypesafeDroppableData = + | CurrentImageDropData + | InitialImageDropData + | ControlNetDropData + | CanvasInitialImageDropData + | NodesImageDropData + | AddToBatchDropData + | NodesMultiImageDropData + | MoveBoardDropData; + +type BaseDragData = { + id: string; +}; + +export type ImageDraggableData = BaseDragData & { + payloadType: 'IMAGE_DTO'; + payload: { imageDTO: ImageDTO }; +}; + +export type GallerySelectionDraggableData = BaseDragData & { + payloadType: 'GALLERY_SELECTION'; +}; + +export type BatchSelectionDraggableData = BaseDragData & { + payloadType: 'BATCH_SELECTION'; +}; + +export type TypesafeDraggableData = + | ImageDraggableData + | GallerySelectionDraggableData + | BatchSelectionDraggableData; + +interface UseDroppableTypesafeArguments + extends Omit { + data?: TypesafeDroppableData; +} + +type UseDroppableTypesafeReturnValue = Omit< + ReturnType, + 'active' | 'over' +> & { + active: TypesafeActive | null; + over: TypesafeOver | null; +}; + +export function useDroppable(props: UseDroppableTypesafeArguments) { + return useOriginalDroppable(props) as UseDroppableTypesafeReturnValue; +} + +interface UseDraggableTypesafeArguments + extends Omit { + data?: TypesafeDraggableData; +} + +type UseDraggableTypesafeReturnValue = Omit< + ReturnType, + 'active' | 'over' +> & { + active: TypesafeActive | null; + over: TypesafeOver | null; +}; + +export function useDraggable(props: UseDraggableTypesafeArguments) { + return useOriginalDraggable(props) as UseDraggableTypesafeReturnValue; +} + +interface TypesafeActive extends Omit { + data: React.MutableRefObject; +} + +interface TypesafeOver extends Omit { + data: React.MutableRefObject; +} + +export const isValidDrop = ( + overData: TypesafeDroppableData | undefined, + active: TypesafeActive | null +) => { + if (!overData || !active?.data.current) { + return false; + } + + const { actionType } = overData; + const { payloadType } = active.data.current; + + if (overData.id === active.data.current.id) { + return false; + } + + switch (actionType) { + case 'SET_CURRENT_IMAGE': + return payloadType === 'IMAGE_DTO'; + case 'SET_INITIAL_IMAGE': + return payloadType === 'IMAGE_DTO'; + case 'SET_CONTROLNET_IMAGE': + return payloadType === 'IMAGE_DTO'; + case 'SET_CANVAS_INITIAL_IMAGE': + return payloadType === 'IMAGE_DTO'; + case 'SET_NODES_IMAGE': + return payloadType === 'IMAGE_DTO'; + case 'SET_MULTI_NODES_IMAGE': + return payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION'; + case 'ADD_TO_BATCH': + return payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION'; + case 'MOVE_BOARD': + return ( + payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION' || 'BATCH_SELECTION' + ); + default: + return false; + } +}; + +interface DragEvent { + activatorEvent: Event; + active: TypesafeActive; + collisions: Collision[] | null; + delta: Translate; + over: TypesafeOver | null; +} + +export interface DragStartEvent extends Pick {} +export interface DragMoveEvent extends DragEvent {} +export interface DragOverEvent extends DragMoveEvent {} +export interface DragEndEvent extends DragEvent {} +export interface DragCancelEvent extends DragEndEvent {} + +export interface DndContextTypesafeProps + extends Omit< + DndContextProps, + 'onDragStart' | 'onDragMove' | 'onDragOver' | 'onDragEnd' | 'onDragCancel' + > { + onDragStart?(event: DragStartEvent): void; + onDragMove?(event: DragMoveEvent): void; + onDragOver?(event: DragOverEvent): void; + onDragEnd?(event: DragEndEvent): void; + onDragCancel?(event: DragCancelEvent): void; +} +export function DndContext(props: DndContextTypesafeProps) { + return ; +} diff --git a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx index 7259f6105d..105f8f18d7 100644 --- a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx +++ b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx @@ -7,7 +7,6 @@ import React, { } from 'react'; import { Provider } from 'react-redux'; import { store } from 'app/store/store'; -// import { OpenAPI } from 'services/api/types'; import Loading from '../../common/components/Loading/Loading'; import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares'; @@ -17,11 +16,6 @@ import '../../i18n'; import { socketMiddleware } from 'services/events/middleware'; import { Middleware } from '@reduxjs/toolkit'; import ImageDndContext from './ImageDnd/ImageDndContext'; -import { - DeleteImageContext, - DeleteImageContextProvider, -} from 'app/contexts/DeleteImageContext'; -import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal'; import { AddImageToBoardContextProvider } from '../contexts/AddImageToBoardContext'; import { $authToken, $baseUrl } from 'services/api/client'; import { DeleteBoardImagesContextProvider } from '../contexts/DeleteBoardImagesContext'; @@ -34,7 +28,6 @@ interface Props extends PropsWithChildren { token?: string; config?: PartialAppConfig; headerComponent?: ReactNode; - setIsReady?: (isReady: boolean) => void; middleware?: Middleware[]; } @@ -43,7 +36,6 @@ const InvokeAIUI = ({ token, config, headerComponent, - setIsReady, middleware, }: Props) => { useEffect(() => { @@ -85,17 +77,11 @@ const InvokeAIUI = ({ }> - - - - - - - + + + + + diff --git a/invokeai/frontend/web/src/app/contexts/DeleteBoardImagesContext.tsx b/invokeai/frontend/web/src/app/contexts/DeleteBoardImagesContext.tsx index 38c89bfcf9..15f9fab282 100644 --- a/invokeai/frontend/web/src/app/contexts/DeleteBoardImagesContext.tsx +++ b/invokeai/frontend/web/src/app/contexts/DeleteBoardImagesContext.tsx @@ -5,15 +5,15 @@ import { useDeleteBoardMutation } from '../../services/api/endpoints/boards'; import { defaultSelectorOptions } from '../store/util/defaultMemoizeOptions'; import { createSelector } from '@reduxjs/toolkit'; import { some } from 'lodash-es'; -import { canvasSelector } from '../../features/canvas/store/canvasSelectors'; -import { controlNetSelector } from '../../features/controlNet/store/controlNetSlice'; -import { selectImagesById } from '../../features/gallery/store/imagesSlice'; -import { nodesSelector } from '../../features/nodes/store/nodesSlice'; -import { generationSelector } from '../../features/parameters/store/generationSelectors'; +import { canvasSelector } from 'features/canvas/store/canvasSelectors'; +import { controlNetSelector } from 'features/controlNet/store/controlNetSlice'; +import { selectImagesById } from 'features/gallery/store/gallerySlice'; +import { nodesSelector } from 'features/nodes/store/nodesSlice'; +import { generationSelector } from 'features/parameters/store/generationSelectors'; import { RootState } from '../store/store'; import { useAppDispatch, useAppSelector } from '../store/storeHooks'; import { ImageUsage } from './DeleteImageContext'; -import { requestedBoardImagesDeletion } from '../../features/gallery/store/actions'; +import { requestedBoardImagesDeletion } from 'features/gallery/store/actions'; export const selectBoardImagesUsage = createSelector( [ diff --git a/invokeai/frontend/web/src/app/contexts/DeleteImageContext.tsx b/invokeai/frontend/web/src/app/contexts/DeleteImageContext.tsx deleted file mode 100644 index 6f4af7608f..0000000000 --- a/invokeai/frontend/web/src/app/contexts/DeleteImageContext.tsx +++ /dev/null @@ -1,201 +0,0 @@ -import { useDisclosure } from '@chakra-ui/react'; -import { createSelector } from '@reduxjs/toolkit'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { requestedImageDeletion } from 'features/gallery/store/actions'; -import { systemSelector } from 'features/system/store/systemSelectors'; -import { - PropsWithChildren, - createContext, - useCallback, - useEffect, - useState, -} from 'react'; -import { ImageDTO } from 'services/api/types'; -import { RootState } from 'app/store/store'; -import { canvasSelector } from 'features/canvas/store/canvasSelectors'; -import { controlNetSelector } from 'features/controlNet/store/controlNetSlice'; -import { nodesSelector } from 'features/nodes/store/nodesSlice'; -import { generationSelector } from 'features/parameters/store/generationSelectors'; -import { some } from 'lodash-es'; - -export type ImageUsage = { - isInitialImage: boolean; - isCanvasImage: boolean; - isNodesImage: boolean; - isControlNetImage: boolean; -}; - -export const selectImageUsage = createSelector( - [ - generationSelector, - canvasSelector, - nodesSelector, - controlNetSelector, - (state: RootState, image_name?: string) => image_name, - ], - (generation, canvas, nodes, controlNet, image_name) => { - const isInitialImage = generation.initialImage?.imageName === image_name; - - const isCanvasImage = canvas.layerState.objects.some( - (obj) => obj.kind === 'image' && obj.imageName === image_name - ); - - const isNodesImage = nodes.nodes.some((node) => { - return some( - node.data.inputs, - (input) => input.type === 'image' && input.value === image_name - ); - }); - - const isControlNetImage = some( - controlNet.controlNets, - (c) => - c.controlImage === image_name || c.processedControlImage === image_name - ); - - const imageUsage: ImageUsage = { - isInitialImage, - isCanvasImage, - isNodesImage, - isControlNetImage, - }; - - return imageUsage; - }, - defaultSelectorOptions -); - -type DeleteImageContextValue = { - /** - * Whether the delete image dialog is open. - */ - isOpen: boolean; - /** - * Closes the delete image dialog. - */ - onClose: () => void; - /** - * Opens the delete image dialog and handles all deletion-related checks. - */ - onDelete: (image?: ImageDTO) => void; - /** - * The image pending deletion - */ - image?: ImageDTO; - /** - * The features in which this image is used - */ - imageUsage?: ImageUsage; - /** - * Immediately deletes an image. - * - * You probably don't want to use this - use `onDelete` instead. - */ - onImmediatelyDelete: () => void; -}; - -export const DeleteImageContext = createContext({ - isOpen: false, - onClose: () => undefined, - onImmediatelyDelete: () => undefined, - onDelete: () => undefined, -}); - -const selector = createSelector( - [systemSelector], - (system) => { - const { isProcessing, isConnected, shouldConfirmOnDelete } = system; - - return { - canDeleteImage: isConnected && !isProcessing, - shouldConfirmOnDelete, - }; - }, - defaultSelectorOptions -); - -type Props = PropsWithChildren; - -export const DeleteImageContextProvider = (props: Props) => { - const { canDeleteImage, shouldConfirmOnDelete } = useAppSelector(selector); - const [imageToDelete, setImageToDelete] = useState(); - const dispatch = useAppDispatch(); - const { isOpen, onOpen, onClose } = useDisclosure(); - - // Check where the image to be deleted is used (eg init image, controlnet, etc.) - const imageUsage = useAppSelector((state) => - selectImageUsage(state, imageToDelete?.image_name) - ); - - // Clean up after deleting or dismissing the modal - const closeAndClearImageToDelete = useCallback(() => { - setImageToDelete(undefined); - onClose(); - }, [onClose]); - - // Dispatch the actual deletion action, to be handled by listener middleware - const handleActualDeletion = useCallback( - (image: ImageDTO) => { - dispatch(requestedImageDeletion({ image, imageUsage })); - closeAndClearImageToDelete(); - }, - [closeAndClearImageToDelete, dispatch, imageUsage] - ); - - // This is intended to be called by the delete button in the dialog - const onImmediatelyDelete = useCallback(() => { - if (canDeleteImage && imageToDelete) { - handleActualDeletion(imageToDelete); - } - closeAndClearImageToDelete(); - }, [ - canDeleteImage, - imageToDelete, - closeAndClearImageToDelete, - handleActualDeletion, - ]); - - const handleGatedDeletion = useCallback( - (image: ImageDTO) => { - if (shouldConfirmOnDelete || some(imageUsage)) { - // If we should confirm on delete, or if the image is in use, open the dialog - onOpen(); - } else { - handleActualDeletion(image); - } - }, - [imageUsage, shouldConfirmOnDelete, onOpen, handleActualDeletion] - ); - - // Consumers of the context call this to delete an image - const onDelete = useCallback((image?: ImageDTO) => { - if (!image) { - return; - } - // Set the image to delete, then let the effect call the actual deletion - setImageToDelete(image); - }, []); - - useEffect(() => { - // We need to use an effect here to trigger the image usage selector, else we get a stale value - if (imageToDelete) { - handleGatedDeletion(imageToDelete); - } - }, [handleGatedDeletion, imageToDelete]); - - return ( - - {props.children} - - ); -}; diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts index cb18d48301..ac1b9c5205 100644 --- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts +++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts @@ -20,10 +20,8 @@ const serializationDenylist: { nodes: nodesPersistDenylist, postprocessing: postprocessingPersistDenylist, system: systemPersistDenylist, - // config: configPersistDenyList, ui: uiPersistDenylist, controlNet: controlNetDenylist, - // hotkeys: hotkeysPersistDenylist, }; export const serialize: SerializeFunction = (data, key) => { diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts index 8f40b0bb59..23e6448987 100644 --- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts +++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts @@ -1,7 +1,6 @@ import { initialCanvasState } from 'features/canvas/store/canvasSlice'; import { initialControlNetState } from 'features/controlNet/store/controlNetSlice'; import { initialGalleryState } from 'features/gallery/store/gallerySlice'; -import { initialImagesState } from 'features/gallery/store/imagesSlice'; import { initialLightboxState } from 'features/lightbox/store/lightboxSlice'; import { initialNodesState } from 'features/nodes/store/nodesSlice'; import { initialGenerationState } from 'features/parameters/store/generationSlice'; @@ -26,7 +25,6 @@ const initialStates: { config: initialConfigState, ui: initialUIState, hotkeys: initialHotkeysState, - images: initialImagesState, controlNet: initialControlNetState, }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index a36141fafc..900fabfee9 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -72,7 +72,6 @@ import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingA import { addImageCategoriesChangedListener } from './listeners/imageCategoriesChanged'; import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed'; import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess'; -import { addUpdateImageUrlsOnConnectListener } from './listeners/updateImageUrlsOnConnect'; import { addImageAddedToBoardFulfilledListener, addImageAddedToBoardRejectedListener, @@ -84,6 +83,9 @@ import { } from './listeners/imageRemovedFromBoard'; import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema'; import { addRequestedBoardImageDeletionListener } from './listeners/boardImagesDeleted'; +import { addSelectionAddedToBatchListener } from './listeners/selectionAddedToBatch'; +import { addImageDroppedListener } from './listeners/imageDropped'; +import { addImageToDeleteSelectedListener } from './listeners/imageToDeleteSelected'; export const listenerMiddleware = createListenerMiddleware(); @@ -126,6 +128,7 @@ addImageDeletedPendingListener(); addImageDeletedFulfilledListener(); addImageDeletedRejectedListener(); addRequestedBoardImageDeletionListener(); +addImageToDeleteSelectedListener(); // Image metadata addImageMetadataReceivedFulfilledListener(); @@ -211,3 +214,9 @@ addBoardIdSelectedListener(); // Node schemas addReceivedOpenAPISchemaListener(); + +// Batches +addSelectionAddedToBatchListener(); + +// DND +addImageDroppedListener(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardIdSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardIdSelected.ts index 1c96c5700d..6ce6665cc5 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardIdSelected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardIdSelected.ts @@ -1,12 +1,14 @@ import { log } from 'app/logging/useLogger'; import { startAppListening } from '..'; -import { boardIdSelected } from 'features/gallery/store/boardSlice'; -import { selectImagesAll } from 'features/gallery/store/imagesSlice'; +import { + imageSelected, + selectImagesAll, + boardIdSelected, +} from 'features/gallery/store/gallerySlice'; import { IMAGES_PER_PAGE, receivedPageOfImages, } from 'services/api/thunks/image'; -import { imageSelected } from 'features/gallery/store/gallerySlice'; import { boardsApi } from 'services/api/endpoints/boards'; const moduleLog = log.child({ namespace: 'boards' }); @@ -28,7 +30,7 @@ export const addBoardIdSelectedListener = () => { return; } - const { categories } = state.images; + const { categories } = state.gallery; const filteredImages = allImages.filter((i) => { const isInCategory = categories.includes(i.image_category); @@ -47,7 +49,7 @@ export const addBoardIdSelectedListener = () => { return; } - dispatch(imageSelected(board.cover_image_name)); + dispatch(imageSelected(board.cover_image_name ?? null)); // if we haven't loaded one full page of images from this board, load more if ( @@ -77,7 +79,7 @@ export const addBoardIdSelected_changeSelectedImage_listener = () => { return; } - const { categories } = state.images; + const { categories } = state.gallery; const filteredImages = selectImagesAll(state).filter((i) => { const isInCategory = categories.includes(i.image_category); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardImagesDeleted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardImagesDeleted.ts index c4d3c5f0ba..4b48aa4626 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardImagesDeleted.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardImagesDeleted.ts @@ -1,11 +1,11 @@ import { requestedBoardImagesDeletion } from 'features/gallery/store/actions'; import { startAppListening } from '..'; -import { imageSelected } from 'features/gallery/store/gallerySlice'; import { + imageSelected, imagesRemoved, selectImagesAll, selectImagesById, -} from 'features/gallery/store/imagesSlice'; +} from 'features/gallery/store/gallerySlice'; import { resetCanvas } from 'features/canvas/store/canvasSlice'; import { controlNetReset } from 'features/controlNet/store/controlNetSlice'; import { clearInitialImage } from 'features/parameters/store/generationSlice'; @@ -22,12 +22,15 @@ export const addRequestedBoardImageDeletionListener = () => { const { board_id } = board; const state = getState(); - const selectedImage = state.gallery.selectedImage - ? selectImagesById(state, state.gallery.selectedImage) + const selectedImageName = + state.gallery.selection[state.gallery.selection.length - 1]; + + const selectedImage = selectedImageName + ? selectImagesById(state, selectedImageName) : undefined; if (selectedImage && selectedImage.board_id === board_id) { - dispatch(imageSelected()); + dispatch(imageSelected(null)); } // We need to reset the features where the board images are in use - none of these work if their image(s) don't exist diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts index af55a1382e..610d89873f 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts @@ -4,7 +4,7 @@ import { log } from 'app/logging/useLogger'; import { imageUploaded } from 'services/api/thunks/image'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; import { addToast } from 'features/system/store/systemSlice'; -import { imageUpserted } from 'features/gallery/store/imagesSlice'; +import { imageUpserted } from 'features/gallery/store/gallerySlice'; const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageCategoriesChanged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageCategoriesChanged.ts index 25b7b7c11f..178cb3c835 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageCategoriesChanged.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageCategoriesChanged.ts @@ -3,8 +3,8 @@ import { startAppListening } from '..'; import { receivedPageOfImages } from 'services/api/thunks/image'; import { imageCategoriesChanged, - selectFilteredImagesAsArray, -} from 'features/gallery/store/imagesSlice'; + selectFilteredImages, +} from 'features/gallery/store/gallerySlice'; const moduleLog = log.child({ namespace: 'gallery' }); @@ -13,7 +13,7 @@ export const addImageCategoriesChangedListener = () => { actionCreator: imageCategoriesChanged, effect: (action, { getState, dispatch }) => { const state = getState(); - const filteredImagesCount = selectFilteredImagesAsArray(state).length; + const filteredImagesCount = selectFilteredImages(state).length; if (!filteredImagesCount) { dispatch( diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts index 91cd509ca6..f083a716a4 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts @@ -1,18 +1,21 @@ -import { requestedImageDeletion } from 'features/gallery/store/actions'; -import { startAppListening } from '..'; -import { imageDeleted } from 'services/api/thunks/image'; import { log } from 'app/logging/useLogger'; -import { clamp } from 'lodash-es'; -import { imageSelected } from 'features/gallery/store/gallerySlice'; -import { - imageRemoved, - selectImagesIds, -} from 'features/gallery/store/imagesSlice'; import { resetCanvas } from 'features/canvas/store/canvasSlice'; import { controlNetReset } from 'features/controlNet/store/controlNetSlice'; -import { clearInitialImage } from 'features/parameters/store/generationSlice'; +import { + imageRemoved, + imageSelected, + selectFilteredImages, +} from 'features/gallery/store/gallerySlice'; +import { + imageDeletionConfirmed, + isModalOpenChanged, +} from 'features/imageDeletion/store/imageDeletionSlice'; import { nodeEditorReset } from 'features/nodes/store/nodesSlice'; +import { clearInitialImage } from 'features/parameters/store/generationSlice'; +import { clamp } from 'lodash-es'; import { api } from 'services/api'; +import { imageDeleted } from 'services/api/thunks/image'; +import { startAppListening } from '..'; const moduleLog = log.child({ namespace: 'image' }); @@ -21,17 +24,22 @@ const moduleLog = log.child({ namespace: 'image' }); */ export const addRequestedImageDeletionListener = () => { startAppListening({ - actionCreator: requestedImageDeletion, + actionCreator: imageDeletionConfirmed, effect: async (action, { dispatch, getState, condition }) => { - const { image, imageUsage } = action.payload; + const { imageDTO, imageUsage } = action.payload; - const { image_name } = image; + dispatch(isModalOpenChanged(false)); + + const { image_name } = imageDTO; const state = getState(); - const selectedImage = state.gallery.selectedImage; + const lastSelectedImage = + state.gallery.selection[state.gallery.selection.length - 1]; - if (selectedImage === image_name) { - const ids = selectImagesIds(state); + if (lastSelectedImage === image_name) { + const filteredImages = selectFilteredImages(state); + + const ids = filteredImages.map((i) => i.image_name); const deletedImageIndex = ids.findIndex( (result) => result.toString() === image_name @@ -50,7 +58,7 @@ export const addRequestedImageDeletionListener = () => { if (newSelectedImageId) { dispatch(imageSelected(newSelectedImageId as string)); } else { - dispatch(imageSelected()); + dispatch(imageSelected(null)); } } @@ -88,7 +96,7 @@ export const addRequestedImageDeletionListener = () => { if (wasImageDeleted) { dispatch( - api.util.invalidateTags([{ type: 'Board', id: image.board_id }]) + api.util.invalidateTags([{ type: 'Board', id: imageDTO.board_id }]) ); } }, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts new file mode 100644 index 0000000000..24a5bffec7 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts @@ -0,0 +1,188 @@ +import { createAction } from '@reduxjs/toolkit'; +import { + TypesafeDraggableData, + TypesafeDroppableData, +} from 'app/components/ImageDnd/typesafeDnd'; +import { log } from 'app/logging/useLogger'; +import { + imageAddedToBatch, + imagesAddedToBatch, +} from 'features/batch/store/batchSlice'; +import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; +import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice'; +import { imageSelected } from 'features/gallery/store/gallerySlice'; +import { + fieldValueChanged, + imageCollectionFieldValueChanged, +} from 'features/nodes/store/nodesSlice'; +import { initialImageChanged } from 'features/parameters/store/generationSlice'; +import { boardImagesApi } from 'services/api/endpoints/boardImages'; +import { startAppListening } from '../'; + +const moduleLog = log.child({ namespace: 'dnd' }); + +export const imageDropped = createAction<{ + overData: TypesafeDroppableData; + activeData: TypesafeDraggableData; +}>('dnd/imageDropped'); + +export const addImageDroppedListener = () => { + startAppListening({ + actionCreator: imageDropped, + effect: (action, { dispatch, getState }) => { + const { activeData, overData } = action.payload; + const { actionType } = overData; + const state = getState(); + + // set current image + if ( + actionType === 'SET_CURRENT_IMAGE' && + activeData.payloadType === 'IMAGE_DTO' && + activeData.payload.imageDTO + ) { + dispatch(imageSelected(activeData.payload.imageDTO.image_name)); + } + + // set initial image + if ( + actionType === 'SET_INITIAL_IMAGE' && + activeData.payloadType === 'IMAGE_DTO' && + activeData.payload.imageDTO + ) { + dispatch(initialImageChanged(activeData.payload.imageDTO)); + } + + // add image to batch + if ( + actionType === 'ADD_TO_BATCH' && + activeData.payloadType === 'IMAGE_DTO' && + activeData.payload.imageDTO + ) { + dispatch(imageAddedToBatch(activeData.payload.imageDTO.image_name)); + } + + // add multiple images to batch + if ( + actionType === 'ADD_TO_BATCH' && + activeData.payloadType === 'GALLERY_SELECTION' + ) { + dispatch(imagesAddedToBatch(state.gallery.selection)); + } + + // set control image + if ( + actionType === 'SET_CONTROLNET_IMAGE' && + activeData.payloadType === 'IMAGE_DTO' && + activeData.payload.imageDTO + ) { + const { controlNetId } = overData.context; + dispatch( + controlNetImageChanged({ + controlImage: activeData.payload.imageDTO.image_name, + controlNetId, + }) + ); + } + + // set canvas image + if ( + actionType === 'SET_CANVAS_INITIAL_IMAGE' && + activeData.payloadType === 'IMAGE_DTO' && + activeData.payload.imageDTO + ) { + dispatch(setInitialCanvasImage(activeData.payload.imageDTO)); + } + + // set nodes image + if ( + actionType === 'SET_NODES_IMAGE' && + activeData.payloadType === 'IMAGE_DTO' && + activeData.payload.imageDTO + ) { + const { fieldName, nodeId } = overData.context; + dispatch( + fieldValueChanged({ + nodeId, + fieldName, + value: activeData.payload.imageDTO, + }) + ); + } + + // set multiple nodes images (single image handler) + if ( + actionType === 'SET_MULTI_NODES_IMAGE' && + activeData.payloadType === 'IMAGE_DTO' && + activeData.payload.imageDTO + ) { + const { fieldName, nodeId } = overData.context; + dispatch( + fieldValueChanged({ + nodeId, + fieldName, + value: [activeData.payload.imageDTO], + }) + ); + } + + // set multiple nodes images (multiple images handler) + if ( + actionType === 'SET_MULTI_NODES_IMAGE' && + activeData.payloadType === 'GALLERY_SELECTION' + ) { + const { fieldName, nodeId } = overData.context; + dispatch( + imageCollectionFieldValueChanged({ + nodeId, + fieldName, + value: state.gallery.selection.map((image_name) => ({ + image_name, + })), + }) + ); + } + + // remove image from board + // TODO: remove board_id from `removeImageFromBoard()` endpoint + // TODO: handle multiple images + // if ( + // actionType === 'MOVE_BOARD' && + // activeData.payloadType === 'IMAGE_DTO' && + // activeData.payload.imageDTO && + // overData.boardId !== null + // ) { + // const { image_name } = activeData.payload.imageDTO; + // dispatch( + // boardImagesApi.endpoints.removeImageFromBoard.initiate({ image_name }) + // ); + // } + + // add image to board + if ( + actionType === 'MOVE_BOARD' && + activeData.payloadType === 'IMAGE_DTO' && + activeData.payload.imageDTO && + overData.context.boardId + ) { + const { image_name } = activeData.payload.imageDTO; + const { boardId } = overData.context; + dispatch( + boardImagesApi.endpoints.addImageToBoard.initiate({ + image_name, + board_id: boardId, + }) + ); + } + + // add multiple images to board + // TODO: add endpoint + // if ( + // actionType === 'ADD_TO_BATCH' && + // activeData.payloadType === 'IMAGE_NAMES' && + // activeData.payload.imageDTONames + // ) { + // dispatch(boardImagesApi.endpoints.addImagesToBoard.intiate({})); + // } + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageMetadataReceived.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageMetadataReceived.ts index 24265faaa9..19af5b24c3 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageMetadataReceived.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageMetadataReceived.ts @@ -1,7 +1,7 @@ import { log } from 'app/logging/useLogger'; import { startAppListening } from '..'; import { imageMetadataReceived, imageUpdated } from 'services/api/thunks/image'; -import { imageUpserted } from 'features/gallery/store/imagesSlice'; +import { imageUpserted } from 'features/gallery/store/gallerySlice'; const moduleLog = log.child({ namespace: 'image' }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageToDeleteSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageToDeleteSelected.ts new file mode 100644 index 0000000000..531981126a --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageToDeleteSelected.ts @@ -0,0 +1,40 @@ +import { startAppListening } from '..'; +import { log } from 'app/logging/useLogger'; +import { + imageDeletionConfirmed, + imageToDeleteSelected, + isModalOpenChanged, + selectImageUsage, +} from 'features/imageDeletion/store/imageDeletionSlice'; + +const moduleLog = log.child({ namespace: 'image' }); + +export const addImageToDeleteSelectedListener = () => { + startAppListening({ + actionCreator: imageToDeleteSelected, + effect: async (action, { dispatch, getState, condition }) => { + const imageDTO = action.payload; + const state = getState(); + const { shouldConfirmOnDelete } = state.system; + const imageUsage = selectImageUsage(getState()); + + if (!imageUsage) { + // should never happen + return; + } + + const isImageInUse = + imageUsage.isCanvasImage || + imageUsage.isInitialImage || + imageUsage.isControlNetImage || + imageUsage.isNodesImage; + + if (shouldConfirmOnDelete || isImageInUse) { + dispatch(isModalOpenChanged(true)); + return; + } + + dispatch(imageDeletionConfirmed({ imageDTO, imageUsage })); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts index f55ed11c8f..0cd852c3de 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts @@ -2,11 +2,12 @@ import { startAppListening } from '..'; import { imageUploaded } from 'services/api/thunks/image'; import { addToast } from 'features/system/store/systemSlice'; import { log } from 'app/logging/useLogger'; -import { imageUpserted } from 'features/gallery/store/imagesSlice'; +import { imageUpserted } from 'features/gallery/store/gallerySlice'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice'; import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; +import { imageAddedToBatch } from 'features/batch/store/batchSlice'; const moduleLog = log.child({ namespace: 'image' }); @@ -70,6 +71,11 @@ export const addImageUploadedFulfilledListener = () => { dispatch(addToast({ title: 'Image Uploaded', status: 'success' })); return; } + + if (postUploadAction?.type === 'ADD_TO_BATCH') { + dispatch(imageAddedToBatch(image.image_name)); + return; + } }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUrlsReceived.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUrlsReceived.ts index c663c64361..0d8aa3d7c9 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUrlsReceived.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUrlsReceived.ts @@ -1,7 +1,7 @@ import { log } from 'app/logging/useLogger'; import { startAppListening } from '..'; import { imageUrlsReceived } from 'services/api/thunks/image'; -import { imageUpdatedOne } from 'features/gallery/store/imagesSlice'; +import { imageUpdatedOne } from 'features/gallery/store/gallerySlice'; const moduleLog = log.child({ namespace: 'image' }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts index 9aca82a32b..fe1a9bd806 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts @@ -4,7 +4,7 @@ import { addToast } from 'features/system/store/systemSlice'; import { startAppListening } from '..'; import { initialImageSelected } from 'features/parameters/store/actions'; import { makeToast } from 'app/components/Toaster'; -import { selectImagesById } from 'features/gallery/store/imagesSlice'; +import { selectImagesById } from 'features/gallery/store/gallerySlice'; import { isImageDTO } from 'services/api/guards'; export const addInitialImageSelectedListener = () => { diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedPageOfImages.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedPageOfImages.ts index e357d38dc3..3c11916be0 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedPageOfImages.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedPageOfImages.ts @@ -2,6 +2,7 @@ import { log } from 'app/logging/useLogger'; import { startAppListening } from '..'; import { serializeError } from 'serialize-error'; import { receivedPageOfImages } from 'services/api/thunks/image'; +import { imagesApi } from 'services/api/endpoints/images'; const moduleLog = log.child({ namespace: 'gallery' }); @@ -9,11 +10,17 @@ export const addReceivedPageOfImagesFulfilledListener = () => { startAppListening({ actionCreator: receivedPageOfImages.fulfilled, effect: (action, { getState, dispatch }) => { - const page = action.payload; + const { items } = action.payload; moduleLog.debug( { data: { payload: action.payload } }, - `Received ${page.items.length} images` + `Received ${items.length} images` ); + + items.forEach((image) => { + dispatch( + imagesApi.util.upsertQueryData('getImageDTO', image.image_name, image) + ); + }); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/selectionAddedToBatch.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/selectionAddedToBatch.ts new file mode 100644 index 0000000000..dae72d92e7 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/selectionAddedToBatch.ts @@ -0,0 +1,19 @@ +import { startAppListening } from '..'; +import { log } from 'app/logging/useLogger'; +import { + imagesAddedToBatch, + selectionAddedToBatch, +} from 'features/batch/store/batchSlice'; + +const moduleLog = log.child({ namespace: 'batch' }); + +export const addSelectionAddedToBatchListener = () => { + startAppListening({ + actionCreator: selectionAddedToBatch, + effect: (action, { dispatch, getState }) => { + const { selection } = getState().gallery; + + dispatch(imagesAddedToBatch(selection)); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts index 976c1558d0..fe4bce682b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts @@ -1,6 +1,5 @@ import { log } from 'app/logging/useLogger'; import { appSocketConnected, socketConnected } from 'services/events/actions'; -import { receivedPageOfImages } from 'services/api/thunks/image'; import { receivedOpenAPISchema } from 'services/api/thunks/schema'; import { startAppListening } from '../..'; @@ -14,19 +13,10 @@ export const addSocketConnectedEventListener = () => { moduleLog.debug({ timestamp }, 'Connected'); - const { nodes, config, images } = getState(); + const { nodes, config } = getState(); const { disabledTabs } = config; - if (!images.ids.length) { - dispatch( - receivedPageOfImages({ - categories: ['general'], - is_intermediate: false, - }) - ); - } - if (!nodes.schema && !disabledTabs.includes('nodes')) { dispatch(receivedOpenAPISchema()); } diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved.ts index bc2c1d1c27..36840e5de1 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved.ts @@ -2,7 +2,7 @@ import { stagingAreaImageSaved } from 'features/canvas/store/actions'; import { startAppListening } from '..'; import { log } from 'app/logging/useLogger'; import { imageUpdated } from 'services/api/thunks/image'; -import { imageUpserted } from 'features/gallery/store/imagesSlice'; +import { imageUpserted } from 'features/gallery/store/gallerySlice'; import { addToast } from 'features/system/store/systemSlice'; const moduleLog = log.child({ namespace: 'canvas' }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateImageUrlsOnConnect.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateImageUrlsOnConnect.ts index 670d762d24..490d99290d 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateImageUrlsOnConnect.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateImageUrlsOnConnect.ts @@ -8,7 +8,7 @@ import { controlNetSelector } from 'features/controlNet/store/controlNetSlice'; import { forEach, uniqBy } from 'lodash-es'; import { imageUrlsReceived } from 'services/api/thunks/image'; import { log } from 'app/logging/useLogger'; -import { selectImagesEntities } from 'features/gallery/store/imagesSlice'; +import { selectImagesEntities } from 'features/gallery/store/gallerySlice'; const moduleLog = log.child({ namespace: 'images' }); @@ -36,7 +36,7 @@ const selectAllUsedImages = createSelector( nodes.nodes.forEach((node) => { forEach(node.data.inputs, (input) => { if (input.type === 'image' && input.value) { - allUsedImages.push(input.value); + allUsedImages.push(input.value.image_name); } }); }); diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index e92a422d68..5208933e7b 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -8,31 +8,32 @@ import { import dynamicMiddlewares from 'redux-dynamic-middlewares'; import { rememberEnhancer, rememberReducer } from 'redux-remember'; +import batchReducer from 'features/batch/store/batchSlice'; import canvasReducer from 'features/canvas/store/canvasSlice'; import controlNetReducer from 'features/controlNet/store/controlNetSlice'; +import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice'; +import boardsReducer from 'features/gallery/store/boardSlice'; import galleryReducer from 'features/gallery/store/gallerySlice'; -import imagesReducer from 'features/gallery/store/imagesSlice'; +import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice'; import lightboxReducer from 'features/lightbox/store/lightboxSlice'; +import loraReducer from 'features/lora/store/loraSlice'; +import nodesReducer from 'features/nodes/store/nodesSlice'; import generationReducer from 'features/parameters/store/generationSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; -import systemReducer from 'features/system/store/systemSlice'; -// import sessionReducer from 'features/system/store/sessionSlice'; -import nodesReducer from 'features/nodes/store/nodesSlice'; -import boardsReducer from 'features/gallery/store/boardSlice'; import configReducer from 'features/system/store/configSlice'; +import systemReducer from 'features/system/store/systemSlice'; import hotkeysReducer from 'features/ui/store/hotkeysSlice'; import uiReducer from 'features/ui/store/uiSlice'; -import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice'; import { listenerMiddleware } from './middleware/listenerMiddleware'; -import { actionSanitizer } from './middleware/devtools/actionSanitizer'; -import { actionsDenylist } from './middleware/devtools/actionsDenylist'; -import { stateSanitizer } from './middleware/devtools/stateSanitizer'; +import { api } from 'services/api'; import { LOCALSTORAGE_PREFIX } from './constants'; import { serialize } from './enhancers/reduxRemember/serialize'; import { unserialize } from './enhancers/reduxRemember/unserialize'; -import { api } from 'services/api'; +import { actionSanitizer } from './middleware/devtools/actionSanitizer'; +import { actionsDenylist } from './middleware/devtools/actionsDenylist'; +import { stateSanitizer } from './middleware/devtools/stateSanitizer'; const allReducers = { canvas: canvasReducer, @@ -45,11 +46,12 @@ const allReducers = { config: configReducer, ui: uiReducer, hotkeys: hotkeysReducer, - images: imagesReducer, controlNet: controlNetReducer, boards: boardsReducer, - // session: sessionReducer, dynamicPrompts: dynamicPromptsReducer, + batch: batchReducer, + imageDeletion: imageDeletionReducer, + lora: loraReducer, [api.reducerPath]: api.reducer, }; @@ -68,6 +70,8 @@ const rememberedKeys: (keyof typeof allReducers)[] = [ 'ui', 'controlNet', 'dynamicPrompts', + 'batch', + 'lora', // 'boards', // 'hotkeys', // 'config', diff --git a/invokeai/frontend/web/src/common/components/IAIButton.tsx b/invokeai/frontend/web/src/common/components/IAIButton.tsx index 3efae76d1e..d1e77537cc 100644 --- a/invokeai/frontend/web/src/common/components/IAIButton.tsx +++ b/invokeai/frontend/web/src/common/components/IAIButton.tsx @@ -15,10 +15,25 @@ export interface IAIButtonProps extends ButtonProps { } const IAIButton = forwardRef((props: IAIButtonProps, forwardedRef) => { - const { children, tooltip = '', tooltipProps, isChecked, ...rest } = props; + const { + children, + tooltip = '', + tooltipProps: { placement = 'top', hasArrow = true, ...tooltipProps } = {}, + isChecked, + ...rest + } = props; return ( - - diff --git a/invokeai/frontend/web/src/common/components/IAICollapse.tsx b/invokeai/frontend/web/src/common/components/IAICollapse.tsx index 5db26f3841..09dc1392e2 100644 --- a/invokeai/frontend/web/src/common/components/IAICollapse.tsx +++ b/invokeai/frontend/web/src/common/components/IAICollapse.tsx @@ -4,22 +4,25 @@ import { Collapse, Flex, Spacer, - Switch, + Text, useColorMode, + useDisclosure, } from '@chakra-ui/react'; +import { AnimatePresence, motion } from 'framer-motion'; import { PropsWithChildren, memo } from 'react'; import { mode } from 'theme/util/mode'; export type IAIToggleCollapseProps = PropsWithChildren & { label: string; - isOpen: boolean; - onToggle: () => void; - withSwitch?: boolean; + activeLabel?: string; + defaultIsOpen?: boolean; }; const IAICollapse = (props: IAIToggleCollapseProps) => { - const { label, isOpen, onToggle, children, withSwitch = false } = props; + const { label, activeLabel, children, defaultIsOpen = false } = props; + const { isOpen, onToggle } = useDisclosure({ defaultIsOpen }); const { colorMode } = useColorMode(); + return ( { alignItems: 'center', p: 2, px: 4, + gap: 2, borderTopRadius: 'base', borderBottomRadius: isOpen ? 0 : 'base', bg: isOpen @@ -48,19 +52,40 @@ const IAICollapse = (props: IAIToggleCollapseProps) => { }} > {label} + + {activeLabel && ( + + + {activeLabel} + + + )} + - {withSwitch && } - {!withSwitch && ( - - )} + void; - onReset?: () => void; + imageDTO: ImageDTO | undefined; onError?: (event: SyntheticEvent) => void; onLoad?: (event: SyntheticEvent) => void; - resetIconSize?: IconButtonProps['size']; + onClick?: (event: MouseEvent) => void; + onClickReset?: (event: MouseEvent) => void; withResetIcon?: boolean; + resetIcon?: ReactElement; + resetTooltip?: string; withMetadataOverlay?: boolean; isDragDisabled?: boolean; isDropDisabled?: boolean; isUploadDisabled?: boolean; - fallback?: ReactElement; - payloadImage?: ImageDTO | null | undefined; minSize?: number; postUploadAction?: PostUploadAction; imageSx?: ChakraProps['sx']; fitContainer?: boolean; + droppableData?: TypesafeDroppableData; + draggableData?: TypesafeDraggableData; + dropLabel?: string; + isSelected?: boolean; + thumbnail?: boolean; + noContentFallback?: ReactElement; }; const IAIDndImage = (props: IAIDndImageProps) => { const { - image, - onDrop, - onReset, + imageDTO, + onClickReset, onError, - resetIconSize = 'md', + onClick, withResetIcon = false, withMetadataOverlay = false, isDropDisabled = false, isDragDisabled = false, isUploadDisabled = false, - fallback = , - payloadImage, minSize = 24, postUploadAction, imageSx, fitContainer = false, + droppableData, + draggableData, + dropLabel, + isSelected = false, + thumbnail = false, + resetTooltip = 'Reset', + resetIcon = , + noContentFallback = , } = props; - const dndId = useRef(uuidv4()); const { colorMode } = useColorMode(); - const { - isOver, - setNodeRef: setDroppableRef, - active: isDropActive, - } = useDroppable({ - id: dndId.current, - disabled: isDropDisabled, - data: { - handleDrop: onDrop, - }, - }); + const dndId = useRef(uuidv4()); const { attributes, listeners, setNodeRef: setDraggableRef, isDragging, + active, } = useDraggable({ id: dndId.current, - data: { - image: payloadImage ? payloadImage : image, - }, - disabled: isDragDisabled || !image, + disabled: isDragDisabled || !imageDTO, + data: draggableData, }); + const { isOver, setNodeRef: setDroppableRef } = useDroppable({ + id: dndId.current, + disabled: isDropDisabled, + data: droppableData, + }); + + const setDndRef = useCombinedRefs(setDroppableRef, setDraggableRef); + const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({ postUploadAction, isDisabled: isUploadDisabled, }); - const setNodeRef = useCombinedRefs(setDroppableRef, setDraggableRef); + const resetIconShadow = useColorModeValue( + `drop-shadow(0px 0px 0.1rem var(--invokeai-colors-base-600))`, + `drop-shadow(0px 0px 0.1rem var(--invokeai-colors-base-800))` + ); const uploadButtonStyles = isUploadDisabled ? {} @@ -117,16 +134,16 @@ const IAIDndImage = (props: IAIDndImageProps) => { alignItems: 'center', justifyContent: 'center', position: 'relative', - minW: minSize, - minH: minSize, + minW: minSize ? minSize : undefined, + minH: minSize ? minSize : undefined, userSelect: 'none', - cursor: isDragDisabled || !image ? 'auto' : 'grab', + cursor: isDragDisabled || !imageDTO ? 'default' : 'pointer', }} {...attributes} {...listeners} - ref={setNodeRef} + ref={setDndRef} > - {image && ( + {imageDTO && ( { }} > } onError={onError} - objectFit="contain" draggable={false} sx={{ + objectFit: 'contain', maxW: 'full', maxH: 'full', borderRadius: 'base', + shadow: isSelected ? 'selected.light' : undefined, + _dark: { shadow: isSelected ? 'selected.dark' : undefined }, ...imageSx, }} /> - {withMetadataOverlay && } - {onReset && withResetIcon && ( - } + {onClickReset && withResetIcon && ( + - } - onClick={onReset} - /> - + /> )} - - {isDropActive && } - )} - {!image && ( + {!imageDTO && !isUploadDisabled && ( <> { > - - {isDropActive && } - )} + {!imageDTO && isUploadDisabled && noContentFallback} + + {isValidDrop(droppableData, active) && !isDragging && ( + + )} + ); }; diff --git a/invokeai/frontend/web/src/common/components/IAIDropOverlay.tsx b/invokeai/frontend/web/src/common/components/IAIDropOverlay.tsx index 8ae54c30ab..573a900fef 100644 --- a/invokeai/frontend/web/src/common/components/IAIDropOverlay.tsx +++ b/invokeai/frontend/web/src/common/components/IAIDropOverlay.tsx @@ -62,7 +62,7 @@ export const IAIDropOverlay = (props: Props) => { w: 'full', h: 'full', opacity: 1, - borderWidth: 2, + borderWidth: 3, borderColor: isOver ? mode('base.50', 'base.200')(colorMode) : mode('base.100', 'base.500')(colorMode), @@ -78,10 +78,10 @@ export const IAIDropOverlay = (props: Props) => { sx={{ fontSize: '2xl', fontWeight: 600, - transform: isOver ? 'scale(1.1)' : 'scale(1)', + transform: isOver ? 'scale(1.02)' : 'scale(1)', color: isOver - ? mode('base.100', 'base.100')(colorMode) - : mode('base.200', 'base.500')(colorMode), + ? mode('base.50', 'base.50')(colorMode) + : mode('base.100', 'base.200')(colorMode), transitionProperty: 'common', transitionDuration: '0.1s', }} diff --git a/invokeai/frontend/web/src/common/components/IAIIconButton.tsx b/invokeai/frontend/web/src/common/components/IAIIconButton.tsx index 8ea06a1328..ed1514055e 100644 --- a/invokeai/frontend/web/src/common/components/IAIIconButton.tsx +++ b/invokeai/frontend/web/src/common/components/IAIIconButton.tsx @@ -29,7 +29,7 @@ const IAIIconButton = forwardRef((props: IAIIconButtonProps, forwardedRef) => { diff --git a/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx b/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx index 4cff351aee..a07071ee79 100644 --- a/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx +++ b/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx @@ -1,73 +1,82 @@ import { As, + ChakraProps, Flex, - FlexProps, Icon, - IconProps, + Skeleton, Spinner, - SpinnerProps, - useColorMode, + StyleProps, + Text, } from '@chakra-ui/react'; import { FaImage } from 'react-icons/fa'; -import { mode } from 'theme/util/mode'; +import { ImageDTO } from 'services/api/types'; -type Props = FlexProps & { - spinnerProps?: SpinnerProps; -}; +type Props = { image: ImageDTO | undefined }; + +export const IAILoadingImageFallback = (props: Props) => { + if (props.image) { + return ( + + ); + } -export const IAIImageLoadingFallback = (props: Props) => { - const { spinnerProps, ...rest } = props; - const { sx, ...restFlexProps } = rest; - const { colorMode } = useColorMode(); return ( - + ); }; type IAINoImageFallbackProps = { - flexProps?: FlexProps; - iconProps?: IconProps; - as?: As; + label?: string; + icon?: As; + boxSize?: StyleProps['boxSize']; + sx?: ChakraProps['sx']; }; -export const IAINoImageFallback = (props: IAINoImageFallbackProps) => { - const { sx: flexSx, ...restFlexProps } = props.flexProps ?? { sx: {} }; - const { sx: iconSx, ...restIconProps } = props.iconProps ?? { sx: {} }; - const { colorMode } = useColorMode(); +export const IAINoContentFallback = (props: IAINoImageFallbackProps) => { + const { icon = FaImage, boxSize = 16 } = props; return ( - + + {props.label && {props.label}} ); }; diff --git a/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx b/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx index 39ec6fd245..9a0bc865a4 100644 --- a/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx +++ b/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx @@ -1,15 +1,16 @@ import { Tooltip, useColorMode, useToken } from '@chakra-ui/react'; import { MultiSelect, MultiSelectProps } from '@mantine/core'; import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; -import { memo } from 'react'; +import { RefObject, memo } from 'react'; import { mode } from 'theme/util/mode'; type IAIMultiSelectProps = MultiSelectProps & { tooltip?: string; + inputRef?: RefObject; }; const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { - const { searchable = true, tooltip, ...rest } = props; + const { searchable = true, tooltip, inputRef, ...rest } = props; const { base50, base100, @@ -33,6 +34,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { return ( ({ label: { @@ -61,7 +63,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { '&:focus-within': { borderColor: mode(accent200, accent600)(colorMode), }, - '&:disabled': { + '&[data-disabled]': { backgroundColor: mode(base300, base700)(colorMode), color: mode(base600, base400)(colorMode), }, diff --git a/invokeai/frontend/web/src/common/components/IAIMantineSelect.tsx b/invokeai/frontend/web/src/common/components/IAIMantineSelect.tsx index 9b023fd2d7..585dc106a8 100644 --- a/invokeai/frontend/web/src/common/components/IAIMantineSelect.tsx +++ b/invokeai/frontend/web/src/common/components/IAIMantineSelect.tsx @@ -64,7 +64,7 @@ const IAIMantineSelect = (props: IAISelectProps) => { '&:focus-within': { borderColor: mode(accent200, accent600)(colorMode), }, - '&:disabled': { + '&[data-disabled]': { backgroundColor: mode(base300, base700)(colorMode), color: mode(base600, base400)(colorMode), }, diff --git a/invokeai/frontend/web/src/common/components/IAISwitch.tsx b/invokeai/frontend/web/src/common/components/IAISwitch.tsx index 54a3b30a4f..d25ab0d87e 100644 --- a/invokeai/frontend/web/src/common/components/IAISwitch.tsx +++ b/invokeai/frontend/web/src/common/components/IAISwitch.tsx @@ -36,7 +36,6 @@ const IAISwitch = (props: Props) => { isDisabled={isDisabled} width={width} display="flex" - gap={4} alignItems="center" {...formControlProps} > @@ -47,6 +46,7 @@ const IAISwitch = (props: Props) => { sx={{ cursor: isDisabled ? 'not-allowed' : 'pointer', ...formLabelProps?.sx, + pe: 4, }} {...formLabelProps} > diff --git a/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts b/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts index d410c3917c..605aa8b162 100644 --- a/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts +++ b/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts @@ -1,27 +1,49 @@ import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { validateSeedWeights } from 'common/util/seedWeightPairs'; import { generationSelector } from 'features/parameters/store/generationSelectors'; import { systemSelector } from 'features/system/store/systemSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; +import { + modelsApi, + useGetMainModelsQuery, +} from '../../services/api/endpoints/models'; const readinessSelector = createSelector( - [generationSelector, systemSelector, activeTabNameSelector], - (generation, system, activeTabName) => { + [stateSelector, activeTabNameSelector], + (state, activeTabName) => { + const { generation, system, batch } = state; const { shouldGenerateVariations, seedWeights, initialImage, seed } = generation; const { isProcessing, isConnected } = system; + const { + isEnabled: isBatchEnabled, + asInitialImage, + imageNames: batchImageNames, + } = batch; let isReady = true; const reasonsWhyNotReady: string[] = []; - if (activeTabName === 'img2img' && !initialImage) { + if ( + activeTabName === 'img2img' && + !initialImage && + !(asInitialImage && batchImageNames.length > 1) + ) { isReady = false; reasonsWhyNotReady.push('No initial image selected'); } + const { isSuccess: mainModelsSuccessfullyLoaded } = + modelsApi.endpoints.getMainModels.select()(state); + if (!mainModelsSuccessfullyLoaded) { + isReady = false; + reasonsWhyNotReady.push('Models are not loaded'); + } + // TODO: job queue // Cannot generate if already processing an image if (isProcessing) { diff --git a/invokeai/frontend/web/src/features/batch/components/BatchControlNet.tsx b/invokeai/frontend/web/src/features/batch/components/BatchControlNet.tsx new file mode 100644 index 0000000000..4231c84bec --- /dev/null +++ b/invokeai/frontend/web/src/features/batch/components/BatchControlNet.tsx @@ -0,0 +1,67 @@ +import { + Flex, + FormControl, + FormLabel, + Heading, + Spacer, + Switch, + Text, +} from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAISwitch from 'common/components/IAISwitch'; +import { ControlNetConfig } from 'features/controlNet/store/controlNetSlice'; +import { ChangeEvent, memo, useCallback } from 'react'; +import { controlNetToggled } from '../store/batchSlice'; + +type Props = { + controlNet: ControlNetConfig; +}; + +const selector = createSelector( + [stateSelector, (state, controlNetId: string) => controlNetId], + (state, controlNetId) => { + const isControlNetEnabled = state.batch.controlNets.includes(controlNetId); + return { isControlNetEnabled }; + }, + defaultSelectorOptions +); + +const BatchControlNet = (props: Props) => { + const dispatch = useAppDispatch(); + const { isControlNetEnabled } = useAppSelector((state) => + selector(state, props.controlNet.controlNetId) + ); + const { processorType, model } = props.controlNet; + + const handleChangeAsControlNet = useCallback(() => { + dispatch(controlNetToggled(props.controlNet.controlNetId)); + }, [dispatch, props.controlNet.controlNetId]); + + return ( + + + + + ControlNet + + + + + + + Model: {model} + + + Processor: {processorType} + + + ); +}; + +export default memo(BatchControlNet); diff --git a/invokeai/frontend/web/src/features/batch/components/BatchImage.tsx b/invokeai/frontend/web/src/features/batch/components/BatchImage.tsx new file mode 100644 index 0000000000..4a6250f93a --- /dev/null +++ b/invokeai/frontend/web/src/features/batch/components/BatchImage.tsx @@ -0,0 +1,116 @@ +import { Box, Icon, Skeleton } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIDndImage from 'common/components/IAIDndImage'; +import { + batchImageRangeEndSelected, + batchImageSelected, + batchImageSelectionToggled, + imageRemovedFromBatch, +} from 'features/batch/store/batchSlice'; +import { MouseEvent, memo, useCallback, useMemo } from 'react'; +import { FaExclamationCircle } from 'react-icons/fa'; +import { useGetImageDTOQuery } from 'services/api/endpoints/images'; + +const makeSelector = (image_name: string) => + createSelector( + [stateSelector], + (state) => ({ + selectionCount: state.batch.selection.length, + isSelected: state.batch.selection.includes(image_name), + }), + defaultSelectorOptions + ); + +type BatchImageProps = { + imageName: string; +}; + +const BatchImage = (props: BatchImageProps) => { + const { + currentData: imageDTO, + isFetching, + isError, + isSuccess, + } = useGetImageDTOQuery(props.imageName); + const dispatch = useAppDispatch(); + + const selector = useMemo( + () => makeSelector(props.imageName), + [props.imageName] + ); + + const { isSelected, selectionCount } = useAppSelector(selector); + + const handleClickRemove = useCallback(() => { + dispatch(imageRemovedFromBatch(props.imageName)); + }, [dispatch, props.imageName]); + + const handleClick = useCallback( + (e: MouseEvent) => { + if (e.shiftKey) { + dispatch(batchImageRangeEndSelected(props.imageName)); + } else if (e.ctrlKey || e.metaKey) { + dispatch(batchImageSelectionToggled(props.imageName)); + } else { + dispatch(batchImageSelected(props.imageName)); + } + }, + [dispatch, props.imageName] + ); + + const draggableData = useMemo(() => { + if (selectionCount > 1) { + return { + id: 'batch', + payloadType: 'BATCH_SELECTION', + }; + } + + if (imageDTO) { + return { + id: 'batch', + payloadType: 'IMAGE_DTO', + payload: { imageDTO }, + }; + } + }, [imageDTO, selectionCount]); + + if (isError) { + return ; + } + + if (isFetching) { + return ( + + + + ); + } + + return ( + + + + ); +}; + +export default memo(BatchImage); diff --git a/invokeai/frontend/web/src/features/batch/components/BatchImageContainer.tsx b/invokeai/frontend/web/src/features/batch/components/BatchImageContainer.tsx new file mode 100644 index 0000000000..09e6b8afd7 --- /dev/null +++ b/invokeai/frontend/web/src/features/batch/components/BatchImageContainer.tsx @@ -0,0 +1,31 @@ +import { Box } from '@chakra-ui/react'; +import BatchImageGrid from './BatchImageGrid'; +import IAIDropOverlay from 'common/components/IAIDropOverlay'; +import { + AddToBatchDropData, + isValidDrop, + useDroppable, +} from 'app/components/ImageDnd/typesafeDnd'; + +const droppableData: AddToBatchDropData = { + id: 'batch', + actionType: 'ADD_TO_BATCH', +}; + +const BatchImageContainer = () => { + const { isOver, setNodeRef, active } = useDroppable({ + id: 'batch-manager', + data: droppableData, + }); + + return ( + + + {isValidDrop(droppableData, active) && ( + + )} + + ); +}; + +export default BatchImageContainer; diff --git a/invokeai/frontend/web/src/features/batch/components/BatchImageGrid.tsx b/invokeai/frontend/web/src/features/batch/components/BatchImageGrid.tsx new file mode 100644 index 0000000000..f61d27d4cf --- /dev/null +++ b/invokeai/frontend/web/src/features/batch/components/BatchImageGrid.tsx @@ -0,0 +1,54 @@ +import { FaImages } from 'react-icons/fa'; +import { Grid, GridItem } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import BatchImage from './BatchImage'; +import { IAINoContentFallback } from 'common/components/IAIImageFallback'; + +const selector = createSelector( + stateSelector, + (state) => { + const imageNames = state.batch.imageNames.concat().reverse(); + + return { imageNames }; + }, + defaultSelectorOptions +); + +const BatchImageGrid = () => { + const { imageNames } = useAppSelector(selector); + + if (imageNames.length === 0) { + return ( + + ); + } + + return ( + + {imageNames.map((imageName) => ( + + + + ))} + + ); +}; + +export default BatchImageGrid; diff --git a/invokeai/frontend/web/src/features/batch/components/BatchManager.tsx b/invokeai/frontend/web/src/features/batch/components/BatchManager.tsx new file mode 100644 index 0000000000..d7855dd4e2 --- /dev/null +++ b/invokeai/frontend/web/src/features/batch/components/BatchManager.tsx @@ -0,0 +1,103 @@ +import { Flex, Heading, Spacer } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useCallback } from 'react'; +import IAISwitch from 'common/components/IAISwitch'; +import { + asInitialImageToggled, + batchReset, + isEnabledChanged, +} from 'features/batch/store/batchSlice'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIButton from 'common/components/IAIButton'; +import BatchImageContainer from './BatchImageGrid'; +import { map } from 'lodash-es'; +import BatchControlNet from './BatchControlNet'; + +const selector = createSelector( + stateSelector, + (state) => { + const { controlNets } = state.controlNet; + const { + imageNames, + asInitialImage, + controlNets: batchControlNets, + isEnabled, + } = state.batch; + + return { + imageCount: imageNames.length, + asInitialImage, + controlNets, + batchControlNets, + isEnabled, + }; + }, + defaultSelectorOptions +); + +const BatchManager = () => { + const dispatch = useAppDispatch(); + const { imageCount, isEnabled, controlNets, batchControlNets } = + useAppSelector(selector); + + const handleResetBatch = useCallback(() => { + dispatch(batchReset()); + }, [dispatch]); + + const handleToggle = useCallback(() => { + dispatch(isEnabledChanged(!isEnabled)); + }, [dispatch, isEnabled]); + + const handleChangeAsInitialImage = useCallback(() => { + dispatch(asInitialImageToggled()); + }, [dispatch]); + + return ( + + + + {imageCount || 'No'} images + + + Reset + + + + {map(controlNets, (controlNet) => { + return ( + + ); + })} + + + + ); +}; + +export default BatchManager; diff --git a/invokeai/frontend/web/src/features/batch/store/batchSlice.ts b/invokeai/frontend/web/src/features/batch/store/batchSlice.ts new file mode 100644 index 0000000000..6a96361d3f --- /dev/null +++ b/invokeai/frontend/web/src/features/batch/store/batchSlice.ts @@ -0,0 +1,142 @@ +import { PayloadAction, createAction, createSlice } from '@reduxjs/toolkit'; +import { uniq } from 'lodash-es'; +import { imageDeleted } from 'services/api/thunks/image'; + +type BatchState = { + isEnabled: boolean; + imageNames: string[]; + asInitialImage: boolean; + controlNets: string[]; + selection: string[]; +}; + +export const initialBatchState: BatchState = { + isEnabled: false, + imageNames: [], + asInitialImage: false, + controlNets: [], + selection: [], +}; + +const batch = createSlice({ + name: 'batch', + initialState: initialBatchState, + reducers: { + isEnabledChanged: (state, action: PayloadAction) => { + state.isEnabled = action.payload; + }, + imageAddedToBatch: (state, action: PayloadAction) => { + state.imageNames = uniq(state.imageNames.concat(action.payload)); + }, + imagesAddedToBatch: (state, action: PayloadAction) => { + state.imageNames = uniq(state.imageNames.concat(action.payload)); + }, + imageRemovedFromBatch: (state, action: PayloadAction) => { + state.imageNames = state.imageNames.filter( + (imageName) => action.payload !== imageName + ); + state.selection = state.selection.filter( + (imageName) => action.payload !== imageName + ); + }, + imagesRemovedFromBatch: (state, action: PayloadAction) => { + state.imageNames = state.imageNames.filter( + (imageName) => !action.payload.includes(imageName) + ); + state.selection = state.selection.filter( + (imageName) => !action.payload.includes(imageName) + ); + }, + batchImageRangeEndSelected: (state, action: PayloadAction) => { + const rangeEndImageName = action.payload; + const lastSelectedImage = state.selection[state.selection.length - 1]; + const lastClickedIndex = state.imageNames.findIndex( + (n) => n === lastSelectedImage + ); + const currentClickedIndex = state.imageNames.findIndex( + (n) => n === rangeEndImageName + ); + if (lastClickedIndex > -1 && currentClickedIndex > -1) { + // We have a valid range! + const start = Math.min(lastClickedIndex, currentClickedIndex); + const end = Math.max(lastClickedIndex, currentClickedIndex); + + const imagesToSelect = state.imageNames.slice(start, end + 1); + state.selection = uniq(state.selection.concat(imagesToSelect)); + } + }, + batchImageSelectionToggled: (state, action: PayloadAction) => { + if ( + state.selection.includes(action.payload) && + state.selection.length > 1 + ) { + state.selection = state.selection.filter( + (imageName) => imageName !== action.payload + ); + } else { + state.selection = uniq(state.selection.concat(action.payload)); + } + }, + batchImageSelected: (state, action: PayloadAction) => { + state.selection = action.payload + ? [action.payload] + : [String(state.imageNames[0])]; + }, + batchReset: (state) => { + state.imageNames = []; + state.selection = []; + }, + asInitialImageToggled: (state) => { + state.asInitialImage = !state.asInitialImage; + }, + controlNetAddedToBatch: (state, action: PayloadAction) => { + state.controlNets = uniq(state.controlNets.concat(action.payload)); + }, + controlNetRemovedFromBatch: (state, action: PayloadAction) => { + state.controlNets = state.controlNets.filter( + (controlNetId) => controlNetId !== action.payload + ); + }, + controlNetToggled: (state, action: PayloadAction) => { + if (state.controlNets.includes(action.payload)) { + state.controlNets = state.controlNets.filter( + (controlNetId) => controlNetId !== action.payload + ); + } else { + state.controlNets = uniq(state.controlNets.concat(action.payload)); + } + }, + }, + extraReducers: (builder) => { + builder.addCase(imageDeleted.fulfilled, (state, action) => { + state.imageNames = state.imageNames.filter( + (imageName) => imageName !== action.meta.arg.image_name + ); + state.selection = state.selection.filter( + (imageName) => imageName !== action.meta.arg.image_name + ); + }); + }, +}); + +export const { + isEnabledChanged, + imageAddedToBatch, + imagesAddedToBatch, + imageRemovedFromBatch, + imagesRemovedFromBatch, + asInitialImageToggled, + controlNetAddedToBatch, + controlNetRemovedFromBatch, + batchReset, + controlNetToggled, + batchImageRangeEndSelected, + batchImageSelectionToggled, + batchImageSelected, +} = batch.actions; + +export default batch.reducer; + +export const selectionAddedToBatch = createAction( + 'batch/selectionAddedToBatch' +); diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx index 36d82dc2ee..dde449a464 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx @@ -1,20 +1,22 @@ -import { memo, useCallback, useState } from 'react'; -import { ImageDTO } from 'services/api/types'; +import { Box, Flex, SystemStyleObject } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { skipToken } from '@reduxjs/toolkit/dist/query'; +import { + TypesafeDraggableData, + TypesafeDroppableData, +} from 'app/components/ImageDnd/typesafeDnd'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIDndImage from 'common/components/IAIDndImage'; +import { IAILoadingImageFallback } from 'common/components/IAIImageFallback'; +import { memo, useCallback, useMemo, useState } from 'react'; +import { useGetImageDTOQuery } from 'services/api/endpoints/images'; +import { PostUploadAction } from 'services/api/thunks/image'; import { ControlNetConfig, controlNetImageChanged, controlNetSelector, } from '../store/controlNetSlice'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { Box, Flex, SystemStyleObject } from '@chakra-ui/react'; -import IAIDndImage from 'common/components/IAIDndImage'; -import { createSelector } from '@reduxjs/toolkit'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback'; -import IAIIconButton from 'common/components/IAIIconButton'; -import { FaUndo } from 'react-icons/fa'; -import { useGetImageDTOQuery } from 'services/api/endpoints/images'; -import { skipToken } from '@reduxjs/toolkit/dist/query'; const selector = createSelector( controlNetSelector, @@ -57,22 +59,6 @@ const ControlNetImagePreview = (props: Props) => { isSuccess: isSuccessProcessedControlImage, } = useGetImageDTOQuery(processedControlImageName ?? skipToken); - const handleDrop = useCallback( - (droppedImage: ImageDTO) => { - if (controlImageName === droppedImage.image_name) { - return; - } - setIsMouseOverImage(false); - dispatch( - controlNetImageChanged({ - controlNetId, - controlImage: droppedImage.image_name, - }) - ); - }, - [controlImageName, controlNetId, dispatch] - ); - const handleResetControlImage = useCallback(() => { dispatch(controlNetImageChanged({ controlNetId, controlImage: null })); }, [controlNetId, dispatch]); @@ -84,6 +70,30 @@ const ControlNetImagePreview = (props: Props) => { setIsMouseOverImage(false); }, []); + const draggableData = useMemo(() => { + if (controlImage) { + return { + id: controlNetId, + payloadType: 'IMAGE_DTO', + payload: { imageDTO: controlImage }, + }; + } + }, [controlImage, controlNetId]); + + const droppableData = useMemo( + () => ({ + id: controlNetId, + actionType: 'SET_CONTROLNET_IMAGE', + context: { controlNetId }, + }), + [controlNetId] + ); + + const postUploadAction = useMemo( + () => ({ type: 'SET_CONTROLNET_IMAGE', controlNetId }), + [controlNetId] + ); + const shouldShowProcessedImage = controlImage && processedControlImage && @@ -104,14 +114,14 @@ const ControlNetImagePreview = (props: Props) => { }} > { }} > {pendingControlImages.includes(controlNetId) && ( @@ -145,27 +154,12 @@ const ControlNetImagePreview = (props: Props) => { insetInlineStart: 0, w: 'full', h: 'full', + objectFit: 'contain', }} > - + )} - {controlImage && ( - - } - variant="link" - sx={{ - p: 2, - color: 'base.50', - }} - /> - - )} ); }; diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetFeatureToggle.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetFeatureToggle.tsx new file mode 100644 index 0000000000..3a7eea2fbf --- /dev/null +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetFeatureToggle.tsx @@ -0,0 +1,36 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAISwitch from 'common/components/IAISwitch'; +import { isControlNetEnabledToggled } from 'features/controlNet/store/controlNetSlice'; +import { useCallback } from 'react'; + +const selector = createSelector( + stateSelector, + (state) => { + const { isEnabled } = state.controlNet; + + return { isEnabled }; + }, + defaultSelectorOptions +); + +const ParamControlNetFeatureToggle = () => { + const { isEnabled } = useAppSelector(selector); + const dispatch = useAppDispatch(); + + const handleChange = useCallback(() => { + dispatch(isControlNetEnabledToggled()); + }, [dispatch]); + + return ( + + ); +}; + +export default ParamControlNetFeatureToggle; diff --git a/invokeai/frontend/web/src/features/controlNet/util/getValidControlNets.ts b/invokeai/frontend/web/src/features/controlNet/util/getValidControlNets.ts new file mode 100644 index 0000000000..4bff39db63 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlNet/util/getValidControlNets.ts @@ -0,0 +1,15 @@ +import { filter } from 'lodash-es'; +import { ControlNetConfig } from '../store/controlNetSlice'; + +export const getValidControlNets = ( + controlNets: Record +) => { + const validControlNets = filter( + controlNets, + (c) => + c.isEnabled && + (Boolean(c.processedControlImage) || + (c.processorType === 'none' && Boolean(c.controlImage))) + ); + return validControlNets; +}; diff --git a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCollapse.tsx b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCollapse.tsx index 1aefecf3e6..0e41fad994 100644 --- a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCollapse.tsx +++ b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCollapse.tsx @@ -1,40 +1,30 @@ +import { Flex } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAICollapse from 'common/components/IAICollapse'; -import { useCallback } from 'react'; -import { isEnabledToggled } from '../store/slice'; -import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts'; import ParamDynamicPromptsCombinatorial from './ParamDynamicPromptsCombinatorial'; -import { Flex } from '@chakra-ui/react'; +import ParamDynamicPromptsToggle from './ParamDynamicPromptsEnabled'; +import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts'; const selector = createSelector( stateSelector, (state) => { const { isEnabled } = state.dynamicPrompts; - return { isEnabled }; + return { activeLabel: isEnabled ? 'Enabled' : undefined }; }, defaultSelectorOptions ); const ParamDynamicPromptsCollapse = () => { - const dispatch = useAppDispatch(); - const { isEnabled } = useAppSelector(selector); - - const handleToggleIsEnabled = useCallback(() => { - dispatch(isEnabledToggled()); - }, [dispatch]); + const { activeLabel } = useAppSelector(selector); return ( - + + diff --git a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCombinatorial.tsx b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCombinatorial.tsx index 30c2240c37..cb930acd3b 100644 --- a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCombinatorial.tsx +++ b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCombinatorial.tsx @@ -1,23 +1,23 @@ -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { combinatorialToggled } from '../store/slice'; import { createSelector } from '@reduxjs/toolkit'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { useCallback } from 'react'; import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAISwitch from 'common/components/IAISwitch'; +import { useCallback } from 'react'; +import { combinatorialToggled } from '../store/slice'; const selector = createSelector( stateSelector, (state) => { - const { combinatorial } = state.dynamicPrompts; + const { combinatorial, isEnabled } = state.dynamicPrompts; - return { combinatorial }; + return { combinatorial, isDisabled: !isEnabled }; }, defaultSelectorOptions ); const ParamDynamicPromptsCombinatorial = () => { - const { combinatorial } = useAppSelector(selector); + const { combinatorial, isDisabled } = useAppSelector(selector); const dispatch = useAppDispatch(); const handleChange = useCallback(() => { @@ -26,6 +26,7 @@ const ParamDynamicPromptsCombinatorial = () => { return ( { + const { isEnabled } = state.dynamicPrompts; + + return { isEnabled }; + }, + defaultSelectorOptions +); + +const ParamDynamicPromptsToggle = () => { + const dispatch = useAppDispatch(); + const { isEnabled } = useAppSelector(selector); + + const handleToggleIsEnabled = useCallback(() => { + dispatch(isEnabledToggled()); + }, [dispatch]); + + return ( + + ); +}; + +export default ParamDynamicPromptsToggle; diff --git a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx index 19f02ae3e5..172120fd1e 100644 --- a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx +++ b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx @@ -1,25 +1,31 @@ -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import IAISlider from 'common/components/IAISlider'; -import { maxPromptsChanged, maxPromptsReset } from '../store/slice'; import { createSelector } from '@reduxjs/toolkit'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { useCallback } from 'react'; import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAISlider from 'common/components/IAISlider'; +import { useCallback } from 'react'; +import { maxPromptsChanged, maxPromptsReset } from '../store/slice'; const selector = createSelector( stateSelector, (state) => { - const { maxPrompts, combinatorial } = state.dynamicPrompts; + const { maxPrompts, combinatorial, isEnabled } = state.dynamicPrompts; const { min, sliderMax, inputMax } = state.config.sd.dynamicPrompts.maxPrompts; - return { maxPrompts, min, sliderMax, inputMax, combinatorial }; + return { + maxPrompts, + min, + sliderMax, + inputMax, + isDisabled: !isEnabled || !combinatorial, + }; }, defaultSelectorOptions ); const ParamDynamicPromptsMaxPrompts = () => { - const { maxPrompts, min, sliderMax, inputMax, combinatorial } = + const { maxPrompts, min, sliderMax, inputMax, isDisabled } = useAppSelector(selector); const dispatch = useAppDispatch(); @@ -37,7 +43,7 @@ const ParamDynamicPromptsMaxPrompts = () => { return ( void; +}; + +const AddEmbeddingButton = (props: Props) => { + const { onClick } = props; + return ( + } + sx={{ + p: 2, + color: 'base.700', + _hover: { + color: 'base.550', + }, + _active: { + color: 'base.500', + }, + }} + variant="link" + onClick={onClick} + /> + ); +}; + +export default memo(AddEmbeddingButton); diff --git a/invokeai/frontend/web/src/features/embedding/components/ParamEmbeddingPopover.tsx b/invokeai/frontend/web/src/features/embedding/components/ParamEmbeddingPopover.tsx new file mode 100644 index 0000000000..3c2ded0166 --- /dev/null +++ b/invokeai/frontend/web/src/features/embedding/components/ParamEmbeddingPopover.tsx @@ -0,0 +1,151 @@ +import { + Flex, + Popover, + PopoverBody, + PopoverContent, + PopoverTrigger, + Text, +} from '@chakra-ui/react'; +import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; +import { forEach } from 'lodash-es'; +import { + PropsWithChildren, + forwardRef, + useCallback, + useMemo, + useRef, +} from 'react'; +import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models'; +import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants'; + +type EmbeddingSelectItem = { + label: string; + value: string; + description?: string; +}; + +type Props = PropsWithChildren & { + onSelect: (v: string) => void; + isOpen: boolean; + onClose: () => void; +}; + +const ParamEmbeddingPopover = (props: Props) => { + const { onSelect, isOpen, onClose, children } = props; + const { data: embeddingQueryData } = useGetTextualInversionModelsQuery(); + const inputRef = useRef(null); + + const data = useMemo(() => { + if (!embeddingQueryData) { + return []; + } + + const data: EmbeddingSelectItem[] = []; + + forEach(embeddingQueryData.entities, (embedding, _) => { + if (!embedding) return; + + data.push({ + value: embedding.name, + label: embedding.name, + description: embedding.description, + }); + }); + + return data; + }, [embeddingQueryData]); + + const handleChange = useCallback( + (v: string[]) => { + if (v.length === 0) { + return; + } + + onSelect(v[0]); + }, + [onSelect] + ); + + return ( + + {children} + + + {data.length === 0 ? ( + + + No Embeddings Loaded + + + ) : ( + + item.label.toLowerCase().includes(value.toLowerCase().trim()) || + item.value.toLowerCase().includes(value.toLowerCase().trim()) + } + onChange={handleChange} + /> + )} + + + + ); +}; + +export default ParamEmbeddingPopover; + +interface ItemProps extends React.ComponentPropsWithoutRef<'div'> { + value: string; + label: string; + description?: string; +} + +const SelectItem = forwardRef( + ({ label, description, ...others }: ItemProps, ref) => { + return ( +
+
+ {label} + {description && ( + + {description} + + )} +
+
+ ); + } +); + +SelectItem.displayName = 'SelectItem'; diff --git a/invokeai/frontend/web/src/features/embedding/store/embeddingSlice.ts b/invokeai/frontend/web/src/features/embedding/store/embeddingSlice.ts new file mode 100644 index 0000000000..e69de29bb2 diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/AllImagesBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/AllImagesBoard.tsx index 858329ead6..918e9390f9 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/AllImagesBoard.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/AllImagesBoard.tsx @@ -1,16 +1,16 @@ -import { Flex, Text, useColorMode } from '@chakra-ui/react'; +import { Flex, useColorMode } from '@chakra-ui/react'; import { FaImages } from 'react-icons/fa'; -import { boardIdSelected } from '../../store/boardSlice'; +import { boardIdSelected } from 'features/gallery/store/gallerySlice'; import { useDispatch } from 'react-redux'; -import { IAINoImageFallback } from 'common/components/IAIImageFallback'; +import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import { AnimatePresence } from 'framer-motion'; -import { SelectedItemOverlay } from '../SelectedItemOverlay'; -import { useCallback } from 'react'; -import { ImageDTO } from 'services/api/types'; -import { useRemoveImageFromBoardMutation } from 'services/api/endpoints/boardImages'; -import { useDroppable } from '@dnd-kit/core'; import IAIDropOverlay from 'common/components/IAIDropOverlay'; import { mode } from 'theme/util/mode'; +import { + MoveBoardDropData, + isValidDrop, + useDroppable, +} from 'app/components/ImageDnd/typesafeDnd'; const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => { const dispatch = useDispatch(); @@ -20,31 +20,15 @@ const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => { dispatch(boardIdSelected()); }; - const [removeImageFromBoard, { isLoading }] = - useRemoveImageFromBoardMutation(); + const droppableData: MoveBoardDropData = { + id: 'all-images-board', + actionType: 'MOVE_BOARD', + context: { boardId: null }, + }; - const handleDrop = useCallback( - (droppedImage: ImageDTO) => { - if (!droppedImage.board_id) { - return; - } - removeImageFromBoard({ - board_id: droppedImage.board_id, - image_name: droppedImage.image_name, - }); - }, - [removeImageFromBoard] - ); - - const { - isOver, - setNodeRef, - active: isDropActive, - } = useDroppable({ + const { isOver, setNodeRef, active } = useDroppable({ id: `board_droppable_all_images`, - data: { - handleDrop, - }, + data: droppableData, }); return ( @@ -58,10 +42,10 @@ const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => { h: 'full', borderRadius: 'base', }} - onClick={handleAllImagesBoardClick} > { borderRadius: 'base', w: 'full', aspectRatio: '1/1', + overflow: 'hidden', + shadow: isSelected ? 'selected.light' : undefined, + _dark: { shadow: isSelected ? 'selected.dark' : undefined }, + flexShrink: 0, }} > - + - {isSelected && } - - - {isDropActive && } + {isValidDrop(droppableData, active) && ( + + )} - { }} > All Images - + ); }; diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList.tsx index fb095b9f42..5618c5c5c2 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList.tsx @@ -2,6 +2,7 @@ import { Collapse, Flex, Grid, + GridItem, IconButton, Input, InputGroup, @@ -10,10 +11,7 @@ import { import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { - boardsSelector, - setBoardSearchText, -} from 'features/gallery/store/boardSlice'; +import { setBoardSearchText } from 'features/gallery/store/boardSlice'; import { memo, useState } from 'react'; import HoverableBoard from './HoverableBoard'; import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; @@ -21,11 +19,13 @@ import AddBoardButton from './AddBoardButton'; import AllImagesBoard from './AllImagesBoard'; import { CloseIcon } from '@chakra-ui/icons'; import { useListAllBoardsQuery } from 'services/api/endpoints/boards'; +import { stateSelector } from 'app/store/store'; const selector = createSelector( - [boardsSelector], - (boardsState) => { - const { selectedBoardId, searchText } = boardsState; + [stateSelector], + ({ boards, gallery }) => { + const { searchText } = boards; + const { selectedBoardId } = gallery; return { selectedBoardId, searchText }; }, defaultSelectorOptions @@ -109,20 +109,24 @@ const BoardsList = (props: Props) => { - {!searchMode && } + {!searchMode && ( + + + + )} {filteredBoards && filteredBoards.map((board) => ( - + + + ))} diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/HoverableBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/HoverableBoard.tsx index 118484f305..035ee77f18 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/HoverableBoard.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/HoverableBoard.tsx @@ -15,10 +15,9 @@ import { useAppDispatch } from 'app/store/storeHooks'; import { memo, useCallback, useContext } from 'react'; import { FaFolder, FaTrash } from 'react-icons/fa'; import { ContextMenu } from 'chakra-ui-contextmenu'; -import { BoardDTO, ImageDTO } from 'services/api/types'; -import { IAINoImageFallback } from 'common/components/IAIImageFallback'; -import { boardIdSelected } from 'features/gallery/store/boardSlice'; -import { useAddImageToBoardMutation } from 'services/api/endpoints/boardImages'; +import { BoardDTO } from 'services/api/types'; +import { IAINoContentFallback } from 'common/components/IAIImageFallback'; +import { boardIdSelected } from 'features/gallery/store/gallerySlice'; import { useDeleteBoardMutation, useUpdateBoardMutation, @@ -26,12 +25,15 @@ import { import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { skipToken } from '@reduxjs/toolkit/dist/query'; -import { useDroppable } from '@dnd-kit/core'; import { AnimatePresence } from 'framer-motion'; import IAIDropOverlay from 'common/components/IAIDropOverlay'; -import { SelectedItemOverlay } from '../SelectedItemOverlay'; import { DeleteBoardImagesContext } from '../../../../app/contexts/DeleteBoardImagesContext'; import { mode } from 'theme/util/mode'; +import { + MoveBoardDropData, + isValidDrop, + useDroppable, +} from 'app/components/ImageDnd/typesafeDnd'; interface HoverableBoardProps { board: BoardDTO; @@ -61,9 +63,6 @@ const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => { const [deleteBoard, { isLoading: isDeleteBoardLoading }] = useDeleteBoardMutation(); - const [addImageToBoard, { isLoading: isAddImageToBoardLoading }] = - useAddImageToBoardMutation(); - const handleUpdateBoardName = (newBoardName: string) => { updateBoard({ board_id, changes: { board_name: newBoardName } }); }; @@ -77,29 +76,19 @@ const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => { onClickDeleteBoardImages(board); }, [board, onClickDeleteBoardImages]); - const handleDrop = useCallback( - (droppedImage: ImageDTO) => { - if (droppedImage.board_id === board_id) { - return; - } - addImageToBoard({ board_id, image_name: droppedImage.image_name }); - }, - [addImageToBoard, board_id] - ); + const droppableData: MoveBoardDropData = { + id: board_id, + actionType: 'MOVE_BOARD', + context: { boardId: board_id }, + }; - const { - isOver, - setNodeRef, - active: isDropActive, - } = useDroppable({ + const { isOver, setNodeRef, active } = useDroppable({ id: `board_droppable_${board_id}`, - data: { - handleDrop, - }, + data: droppableData, }); return ( - + menuProps={{ size: 'sm', isLazy: true }} renderMenu={() => ( @@ -148,13 +137,25 @@ const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => { w: 'full', aspectRatio: '1/1', overflow: 'hidden', + shadow: isSelected ? 'selected.light' : undefined, + _dark: { shadow: isSelected ? 'selected.dark' : undefined }, + flexShrink: 0, }} > {board.cover_image_name && coverImage?.image_url && ( )} {!(board.cover_image_name && coverImage?.image_url) && ( - + )} { {board.image_count} - {isSelected && } - - - {isDropActive && } + {isValidDrop(droppableData, active) && ( + + )} - + { }} /> - + )} diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx index 169a965be0..b4a3296f04 100644 --- a/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx @@ -38,8 +38,7 @@ import { FaShare, FaShareAlt, } from 'react-icons/fa'; -import { gallerySelector } from '../store/gallerySelectors'; -import { useCallback, useContext } from 'react'; +import { useCallback } from 'react'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; @@ -49,22 +48,15 @@ import FaceRestoreSettings from 'features/parameters/components/Parameters/FaceR import UpscaleSettings from 'features/parameters/components/Parameters/Upscale/UpscaleSettings'; import { useAppToaster } from 'app/components/Toaster'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; -import { DeleteImageContext } from 'app/contexts/DeleteImageContext'; -import { DeleteImageButton } from './DeleteImageModal'; -import { selectImagesById } from '../store/imagesSlice'; -import { RootState } from 'app/store/store'; +import { stateSelector } from 'app/store/store'; +import { useGetImageDTOQuery } from 'services/api/endpoints/images'; +import { skipToken } from '@reduxjs/toolkit/dist/query'; +import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice'; +import { DeleteImageButton } from 'features/imageDeletion/components/DeleteImageButton'; const currentImageButtonsSelector = createSelector( - [ - (state: RootState) => state, - systemSelector, - gallerySelector, - postprocessingSelector, - uiSelector, - lightboxSelector, - activeTabNameSelector, - ], - (state, system, gallery, postprocessing, ui, lightbox, activeTabName) => { + [stateSelector, activeTabNameSelector], + ({ gallery, system, postprocessing, ui, lightbox }, activeTabName) => { const { isProcessing, isConnected, @@ -84,9 +76,7 @@ const currentImageButtonsSelector = createSelector( shouldShowProgressInViewer, } = ui; - const imageDTO = selectImagesById(state, gallery.selectedImage ?? ''); - - const { selectedImage } = gallery; + const lastSelectedImage = gallery.selection[gallery.selection.length - 1]; return { canDeleteImage: isConnected && !isProcessing, @@ -97,16 +87,13 @@ const currentImageButtonsSelector = createSelector( isESRGANAvailable, upscalingLevel, facetoolStrength, - shouldDisableToolbarButtons: Boolean(progressImage) || !selectedImage, + shouldDisableToolbarButtons: Boolean(progressImage) || !lastSelectedImage, shouldShowImageDetails, activeTabName, isLightboxOpen, shouldHidePreview, - image: imageDTO, - seed: imageDTO?.metadata?.seed, - prompt: imageDTO?.metadata?.positive_conditioning, - negativePrompt: imageDTO?.metadata?.negative_conditioning, shouldShowProgressInViewer, + lastSelectedImage, }; }, { @@ -132,7 +119,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { isLightboxOpen, activeTabName, shouldHidePreview, - image, + lastSelectedImage, shouldShowProgressInViewer, } = useAppSelector(currentImageButtonsSelector); @@ -147,7 +134,9 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { const { recallBothPrompts, recallSeed, recallAllParameters } = useRecallParameters(); - const { onDelete } = useContext(DeleteImageContext); + const { currentData: image } = useGetImageDTOQuery( + lastSelectedImage ?? skipToken + ); // const handleCopyImage = useCallback(async () => { // if (!image?.url) { @@ -248,8 +237,11 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { }, []); const handleDelete = useCallback(() => { - onDelete(image); - }, [image, onDelete]); + if (!image) { + return; + } + dispatch(imageToDeleteSelected(image)); + }, [dispatch, image]); useHotkeys( 'Shift+U', @@ -371,7 +363,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { }} {...props} > - + { } isChecked={isLightboxOpen} onClick={handleLightBox} + isDisabled={shouldDisableToolbarButtons} /> )} - + } tooltip={`${t('parameters.usePrompt')} (P)`} @@ -478,7 +471,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { {(isUpscalingEnabled || isFaceRestoreEnabled) && ( - + {isFaceRestoreEnabled && ( { )} - + } tooltip={`${t('parameters.info')} (I)`} @@ -553,7 +549,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { /> - + { - + diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImageDisplay.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImageDisplay.tsx index 2da5185fe5..1d8863f4d8 100644 --- a/invokeai/frontend/web/src/features/gallery/components/CurrentImageDisplay.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImageDisplay.tsx @@ -1,29 +1,9 @@ import { Flex } from '@chakra-ui/react'; -import { createSelector } from '@reduxjs/toolkit'; -import { useAppSelector } from 'app/store/storeHooks'; -import { systemSelector } from 'features/system/store/systemSelectors'; -import { gallerySelector } from '../store/gallerySelectors'; import CurrentImageButtons from './CurrentImageButtons'; import CurrentImagePreview from './CurrentImagePreview'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; - -export const currentImageDisplaySelector = createSelector( - [systemSelector, gallerySelector], - (system, gallery) => { - const { progressImage } = system; - - return { - hasSelectedImage: Boolean(gallery.selectedImage), - hasProgressImage: Boolean(progressImage), - }; - }, - defaultSelectorOptions -); const CurrentImageDisplay = () => { - const { hasSelectedImage } = useAppSelector(currentImageDisplaySelector); - return ( { justifyContent: 'center', }} > - {hasSelectedImage && } + ); diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx index fac19b347e..8018beea9a 100644 --- a/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx @@ -1,35 +1,33 @@ import { Box, Flex, Image } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { uiSelector } from 'features/ui/store/uiSelectors'; +import { skipToken } from '@reduxjs/toolkit/dist/query'; +import { + TypesafeDraggableData, + TypesafeDroppableData, +} from 'app/components/ImageDnd/typesafeDnd'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import IAIDndImage from 'common/components/IAIDndImage'; +import { selectLastSelectedImage } from 'features/gallery/store/gallerySlice'; import { isEqual } from 'lodash-es'; - -import { gallerySelector } from '../store/gallerySelectors'; +import { memo, useMemo } from 'react'; +import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer'; import NextPrevImageButtons from './NextPrevImageButtons'; -import { memo, useCallback } from 'react'; -import { systemSelector } from 'features/system/store/systemSelectors'; -import { imageSelected } from '../store/gallerySlice'; -import IAIDndImage from 'common/components/IAIDndImage'; -import { ImageDTO } from 'services/api/types'; -import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback'; -import { useGetImageDTOQuery } from 'services/api/endpoints/images'; -import { skipToken } from '@reduxjs/toolkit/dist/query'; export const imagesSelector = createSelector( - [uiSelector, gallerySelector, systemSelector], - (ui, gallery, system) => { + [stateSelector, selectLastSelectedImage], + ({ ui, system }, lastSelectedImage) => { const { shouldShowImageDetails, shouldHidePreview, shouldShowProgressInViewer, } = ui; - const { selectedImage } = gallery; const { progressImage, shouldAntialiasProgressImage } = system; return { shouldShowImageDetails, shouldHidePreview, - selectedImage, + imageName: lastSelectedImage, progressImage, shouldShowProgressInViewer, shouldAntialiasProgressImage, @@ -45,29 +43,35 @@ export const imagesSelector = createSelector( const CurrentImagePreview = () => { const { shouldShowImageDetails, - selectedImage, + imageName, progressImage, shouldShowProgressInViewer, shouldAntialiasProgressImage, } = useAppSelector(imagesSelector); const { - currentData: image, + currentData: imageDTO, isLoading, isError, isSuccess, - } = useGetImageDTOQuery(selectedImage ?? skipToken); + } = useGetImageDTOQuery(imageName ?? skipToken); - const dispatch = useAppDispatch(); + const draggableData = useMemo(() => { + if (imageDTO) { + return { + id: 'current-image', + payloadType: 'IMAGE_DTO', + payload: { imageDTO }, + }; + } + }, [imageDTO]); - const handleDrop = useCallback( - (droppedImage: ImageDTO) => { - if (droppedImage.image_name === image?.image_name) { - return; - } - dispatch(imageSelected(droppedImage.image_name)); - }, - [dispatch, image?.image_name] + const droppableData = useMemo( + () => ({ + id: 'current-image', + actionType: 'SET_CURRENT_IMAGE', + }), + [] ); return ( @@ -98,14 +102,15 @@ const CurrentImagePreview = () => { /> ) : ( } + imageDTO={imageDTO} + droppableData={droppableData} + draggableData={draggableData} isUploadDisabled={true} fitContainer + dropLabel="Set as Current Image" /> )} - {shouldShowImageDetails && image && ( + {shouldShowImageDetails && imageDTO && ( { overflow: 'scroll', }} > - + )} - {!shouldShowImageDetails && image && ( + {!shouldShowImageDetails && imageDTO && ( { - const { shouldConfirmOnDelete } = system; - const { canRestoreDeletedImagesFromBin } = config; - - return { - shouldConfirmOnDelete, - canRestoreDeletedImagesFromBin, - }; - }, - defaultSelectorOptions -); - -const ImageInUseMessage = (props: { imageUsage?: ImageUsage }) => { - const { imageUsage } = props; - - if (!imageUsage) { - return null; - } - - if (!some(imageUsage)) { - return null; - } - - return ( - <> - This image is currently in use in the following features: - - {imageUsage.isInitialImage && Image to Image} - {imageUsage.isCanvasImage && Unified Canvas} - {imageUsage.isControlNetImage && ControlNet} - {imageUsage.isNodesImage && Node Editor} - - - If you delete this image, those features will immediately be reset. - - - ); -}; - -const DeleteImageModal = () => { - const dispatch = useAppDispatch(); - const { t } = useTranslation(); - - const { isOpen, onClose, onImmediatelyDelete, image, imageUsage } = - useContext(DeleteImageContext); - - const { shouldConfirmOnDelete, canRestoreDeletedImagesFromBin } = - useAppSelector(selector); - - const handleChangeShouldConfirmOnDelete = useCallback( - (e: ChangeEvent) => - dispatch(setShouldConfirmOnDelete(!e.target.checked)), - [dispatch] - ); - - const cancelRef = useRef(null); - - return ( - - - - - {t('gallery.deleteImage')} - - - - - - - - {canRestoreDeletedImagesFromBin - ? t('gallery.deleteImageBin') - : t('gallery.deleteImagePermanent')} - - {t('common.areYouSure')} - - - - - - Cancel - - - Delete - - - - - - ); -}; - -export default memo(DeleteImageModal); - -const deleteImageButtonsSelector = createSelector( - [systemSelector], - (system) => { - const { isProcessing, isConnected } = system; - - return isConnected && !isProcessing; - } -); - -type DeleteImageButtonProps = { - onClick: () => void; -}; - -export const DeleteImageButton = (props: DeleteImageButtonProps) => { - const { onClick } = props; - const { t } = useTranslation(); - const canDeleteImage = useAppSelector(deleteImageButtonsSelector); - - return ( - } - tooltip={`${t('gallery.deleteImage')} (Del)`} - aria-label={`${t('gallery.deleteImage')} (Del)`} - isDisabled={!canDeleteImage} - colorScheme="error" - /> - ); -}; diff --git a/invokeai/frontend/web/src/features/gallery/components/GalleryImage.tsx b/invokeai/frontend/web/src/features/gallery/components/GalleryImage.tsx new file mode 100644 index 0000000000..a8d4c84adc --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/GalleryImage.tsx @@ -0,0 +1,132 @@ +import { Box } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIDndImage from 'common/components/IAIDndImage'; +import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice'; +import { MouseEvent, memo, useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { FaTrash } from 'react-icons/fa'; +import { ImageDTO } from 'services/api/types'; +import { + imageRangeEndSelected, + imageSelected, + imageSelectionToggled, +} from '../store/gallerySlice'; +import ImageContextMenu from './ImageContextMenu'; + +export const makeSelector = (image_name: string) => + createSelector( + [stateSelector], + ({ gallery }) => { + const isSelected = gallery.selection.includes(image_name); + const selectionCount = gallery.selection.length; + + return { + isSelected, + selectionCount, + }; + }, + defaultSelectorOptions + ); + +interface HoverableImageProps { + imageDTO: ImageDTO; +} + +/** + * Gallery image component with delete/use all/use seed buttons on hover. + */ +const GalleryImage = (props: HoverableImageProps) => { + const { imageDTO } = props; + const { image_url, thumbnail_url, image_name } = imageDTO; + + const localSelector = useMemo(() => makeSelector(image_name), [image_name]); + + const { isSelected, selectionCount } = useAppSelector(localSelector); + + const dispatch = useAppDispatch(); + + const { t } = useTranslation(); + + const handleClick = useCallback( + (e: MouseEvent) => { + if (e.shiftKey) { + dispatch(imageRangeEndSelected(props.imageDTO.image_name)); + } else if (e.ctrlKey || e.metaKey) { + dispatch(imageSelectionToggled(props.imageDTO.image_name)); + } else { + dispatch(imageSelected(props.imageDTO.image_name)); + } + }, + [dispatch, props.imageDTO.image_name] + ); + + const handleDelete = useCallback( + (e: MouseEvent) => { + e.stopPropagation(); + if (!imageDTO) { + return; + } + dispatch(imageToDeleteSelected(imageDTO)); + }, + [dispatch, imageDTO] + ); + + const draggableData = useMemo(() => { + if (selectionCount > 1) { + return { + id: 'gallery-image', + payloadType: 'GALLERY_SELECTION', + }; + } + + if (imageDTO) { + return { + id: 'gallery-image', + payloadType: 'IMAGE_DTO', + payload: { imageDTO }, + }; + } + }, [imageDTO, selectionCount]); + + return ( + + + {(ref) => ( + + } + resetTooltip="Delete image" + imageSx={{ w: 'full', h: 'full' }} + // withResetIcon // removed bc it's too easy to accidentally delete images + isDropDisabled={true} + isUploadDisabled={true} + /> + + )} + + + ); +}; + +export default memo(GalleryImage); diff --git a/invokeai/frontend/web/src/features/gallery/components/HoverableImage.tsx b/invokeai/frontend/web/src/features/gallery/components/HoverableImage.tsx deleted file mode 100644 index 91648d8df0..0000000000 --- a/invokeai/frontend/web/src/features/gallery/components/HoverableImage.tsx +++ /dev/null @@ -1,371 +0,0 @@ -import { Box, Flex, Icon, Image, MenuItem, MenuList } from '@chakra-ui/react'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { imageSelected } from 'features/gallery/store/gallerySlice'; -import { memo, useCallback, useContext, useState } from 'react'; -import { - FaCheck, - FaExpand, - FaFolder, - FaImage, - FaShare, - FaTrash, -} from 'react-icons/fa'; -import { ContextMenu } from 'chakra-ui-contextmenu'; -import { - resizeAndScaleCanvas, - setInitialCanvasImage, -} from 'features/canvas/store/canvasSlice'; -import { gallerySelector } from 'features/gallery/store/gallerySelectors'; -import { setActiveTab } from 'features/ui/store/uiSlice'; -import { useTranslation } from 'react-i18next'; -import IAIIconButton from 'common/components/IAIIconButton'; -import { ExternalLinkIcon } from '@chakra-ui/icons'; -import { IoArrowUndoCircleOutline } from 'react-icons/io5'; -import { createSelector } from '@reduxjs/toolkit'; -import { systemSelector } from 'features/system/store/systemSelectors'; -import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors'; -import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; -import { isEqual } from 'lodash-es'; -import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; -import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; -import { initialImageSelected } from 'features/parameters/store/actions'; -import { sentImageToCanvas, sentImageToImg2Img } from '../store/actions'; -import { useAppToaster } from 'app/components/Toaster'; -import { ImageDTO } from 'services/api/types'; -import { useDraggable } from '@dnd-kit/core'; -import { DeleteImageContext } from 'app/contexts/DeleteImageContext'; -import { AddImageToBoardContext } from '../../../app/contexts/AddImageToBoardContext'; -import { useRemoveImageFromBoardMutation } from 'services/api/endpoints/boardImages'; - -export const selector = createSelector( - [gallerySelector, systemSelector, lightboxSelector, activeTabNameSelector], - (gallery, system, lightbox, activeTabName) => { - const { - galleryImageObjectFit, - galleryImageMinimumWidth, - shouldUseSingleGalleryColumn, - } = gallery; - - const { isLightboxOpen } = lightbox; - const { isConnected, isProcessing, shouldConfirmOnDelete } = system; - - return { - canDeleteImage: isConnected && !isProcessing, - shouldConfirmOnDelete, - galleryImageObjectFit, - galleryImageMinimumWidth, - shouldUseSingleGalleryColumn, - activeTabName, - isLightboxOpen, - }; - }, - { - memoizeOptions: { - resultEqualityCheck: isEqual, - }, - } -); - -interface HoverableImageProps { - image: ImageDTO; - isSelected: boolean; -} - -/** - * Gallery image component with delete/use all/use seed buttons on hover. - */ -const HoverableImage = (props: HoverableImageProps) => { - const dispatch = useAppDispatch(); - const { - activeTabName, - galleryImageObjectFit, - galleryImageMinimumWidth, - canDeleteImage, - shouldUseSingleGalleryColumn, - } = useAppSelector(selector); - - const { image, isSelected } = props; - const { image_url, thumbnail_url, image_name } = image; - - const [isHovered, setIsHovered] = useState(false); - const toaster = useAppToaster(); - - const { t } = useTranslation(); - const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled; - const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled; - - const { onDelete } = useContext(DeleteImageContext); - const { onClickAddToBoard } = useContext(AddImageToBoardContext); - const handleDelete = useCallback(() => { - onDelete(image); - }, [image, onDelete]); - const { recallBothPrompts, recallSeed, recallAllParameters } = - useRecallParameters(); - - const { attributes, listeners, setNodeRef } = useDraggable({ - id: `galleryImage_${image_name}`, - data: { - image, - }, - }); - - const [removeFromBoard] = useRemoveImageFromBoardMutation(); - - const handleMouseOver = () => setIsHovered(true); - const handleMouseOut = () => setIsHovered(false); - - const handleSelectImage = useCallback(() => { - dispatch(imageSelected(image.image_name)); - }, [image, dispatch]); - - // Recall parameters handlers - const handleRecallPrompt = useCallback(() => { - recallBothPrompts( - image.metadata?.positive_conditioning, - image.metadata?.negative_conditioning - ); - }, [ - image.metadata?.negative_conditioning, - image.metadata?.positive_conditioning, - recallBothPrompts, - ]); - - const handleRecallSeed = useCallback(() => { - recallSeed(image.metadata?.seed); - }, [image, recallSeed]); - - const handleSendToImageToImage = useCallback(() => { - dispatch(sentImageToImg2Img()); - dispatch(initialImageSelected(image)); - }, [dispatch, image]); - - // const handleRecallInitialImage = useCallback(() => { - // recallInitialImage(image.metadata.invokeai?.node?.image); - // }, [image, recallInitialImage]); - - /** - * TODO: the rest of these - */ - const handleSendToCanvas = () => { - dispatch(sentImageToCanvas()); - dispatch(setInitialCanvasImage(image)); - - dispatch(resizeAndScaleCanvas()); - - if (activeTabName !== 'unifiedCanvas') { - dispatch(setActiveTab('unifiedCanvas')); - } - - toaster({ - title: t('toast.sentToUnifiedCanvas'), - status: 'success', - duration: 2500, - isClosable: true, - }); - }; - - const handleUseAllParameters = useCallback(() => { - recallAllParameters(image); - }, [image, recallAllParameters]); - - const handleLightBox = () => { - // dispatch(setCurrentImage(image)); - // dispatch(setIsLightboxOpen(true)); - }; - - const handleAddToBoard = useCallback(() => { - onClickAddToBoard(image); - }, [image, onClickAddToBoard]); - - const handleRemoveFromBoard = useCallback(() => { - if (!image.board_id) { - return; - } - removeFromBoard({ board_id: image.board_id, image_name: image.image_name }); - }, [image.board_id, image.image_name, removeFromBoard]); - - const handleOpenInNewTab = () => { - window.open(image.image_url, '_blank'); - }; - - return ( - - - menuProps={{ size: 'sm', isLazy: true }} - renderMenu={() => ( - - } - onClickCapture={handleOpenInNewTab} - > - {t('common.openInNewTab')} - - {isLightboxEnabled && ( - } onClickCapture={handleLightBox}> - {t('parameters.openInViewer')} - - )} - } - onClickCapture={handleRecallPrompt} - isDisabled={image?.metadata?.positive_conditioning === undefined} - > - {t('parameters.usePrompt')} - - - } - onClickCapture={handleRecallSeed} - isDisabled={image?.metadata?.seed === undefined} - > - {t('parameters.useSeed')} - - {/* } - onClickCapture={handleRecallInitialImage} - isDisabled={image?.metadata?.type !== 'img2img'} - > - {t('parameters.useInitImg')} - */} - } - onClickCapture={handleUseAllParameters} - isDisabled={ - // what should these be - !['t2l', 'l2l', 'inpaint'].includes( - String(image?.metadata?.type) - ) - } - > - {t('parameters.useAll')} - - } - onClickCapture={handleSendToImageToImage} - id="send-to-img2img" - > - {t('parameters.sendToImg2Img')} - - {isCanvasEnabled && ( - } - onClickCapture={handleSendToCanvas} - id="send-to-canvas" - > - {t('parameters.sendToUnifiedCanvas')} - - )} - } onClickCapture={handleAddToBoard}> - {image.board_id ? 'Change Board' : 'Add to Board'} - - {image.board_id && ( - } - onClickCapture={handleRemoveFromBoard} - > - Remove from Board - - )} - } - onClickCapture={handleDelete} - > - {t('gallery.deleteImage')} - - - )} - > - {(ref) => ( - - } - sx={{ - width: '100%', - height: '100%', - maxWidth: '100%', - maxHeight: '100%', - }} - /> - {isSelected && ( - - - - )} - {isHovered && galleryImageMinimumWidth >= 100 && ( - - } - size="xs" - fontSize={14} - isDisabled={!canDeleteImage} - /> - - )} - - )} - - - ); -}; - -export default memo(HoverableImage); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu.tsx new file mode 100644 index 0000000000..1e5f95ab0d --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu.tsx @@ -0,0 +1,278 @@ +import { MenuItem, MenuList } from '@chakra-ui/react'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { memo, useCallback, useContext } from 'react'; +import { + FaExpand, + FaFolder, + FaFolderPlus, + FaShare, + FaTrash, +} from 'react-icons/fa'; +import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu'; +import { + resizeAndScaleCanvas, + setInitialCanvasImage, +} from 'features/canvas/store/canvasSlice'; +import { setActiveTab } from 'features/ui/store/uiSlice'; +import { useTranslation } from 'react-i18next'; +import { ExternalLinkIcon } from '@chakra-ui/icons'; +import { IoArrowUndoCircleOutline } from 'react-icons/io5'; +import { createSelector } from '@reduxjs/toolkit'; +import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; +import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; +import { initialImageSelected } from 'features/parameters/store/actions'; +import { sentImageToCanvas, sentImageToImg2Img } from '../store/actions'; +import { useAppToaster } from 'app/components/Toaster'; +import { AddImageToBoardContext } from '../../../app/contexts/AddImageToBoardContext'; +import { useRemoveImageFromBoardMutation } from 'services/api/endpoints/boardImages'; +import { ImageDTO } from 'services/api/types'; +import { RootState, stateSelector } from 'app/store/store'; +import { + imagesAddedToBatch, + selectionAddedToBatch, +} from 'features/batch/store/batchSlice'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice'; + +const selector = createSelector( + [stateSelector, (state: RootState, imageDTO: ImageDTO) => imageDTO], + ({ gallery, batch }, imageDTO) => { + const selectionCount = gallery.selection.length; + const isInBatch = batch.imageNames.includes(imageDTO.image_name); + + return { selectionCount, isInBatch }; + }, + defaultSelectorOptions +); + +type Props = { + image: ImageDTO; + children: ContextMenuProps['children']; +}; + +const ImageContextMenu = ({ image, children }: Props) => { + const { selectionCount, isInBatch } = useAppSelector((state) => + selector(state, image) + ); + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + + const toaster = useAppToaster(); + + const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled; + const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled; + + const { onClickAddToBoard } = useContext(AddImageToBoardContext); + + const handleDelete = useCallback(() => { + if (!image) { + return; + } + dispatch(imageToDeleteSelected(image)); + }, [dispatch, image]); + + const { recallBothPrompts, recallSeed, recallAllParameters } = + useRecallParameters(); + + const [removeFromBoard] = useRemoveImageFromBoardMutation(); + + // Recall parameters handlers + const handleRecallPrompt = useCallback(() => { + recallBothPrompts( + image.metadata?.positive_conditioning, + image.metadata?.negative_conditioning + ); + }, [ + image.metadata?.negative_conditioning, + image.metadata?.positive_conditioning, + recallBothPrompts, + ]); + + const handleRecallSeed = useCallback(() => { + recallSeed(image.metadata?.seed); + }, [image, recallSeed]); + + const handleSendToImageToImage = useCallback(() => { + dispatch(sentImageToImg2Img()); + dispatch(initialImageSelected(image)); + }, [dispatch, image]); + + // const handleRecallInitialImage = useCallback(() => { + // recallInitialImage(image.metadata.invokeai?.node?.image); + // }, [image, recallInitialImage]); + + const handleSendToCanvas = () => { + dispatch(sentImageToCanvas()); + dispatch(setInitialCanvasImage(image)); + dispatch(resizeAndScaleCanvas()); + dispatch(setActiveTab('unifiedCanvas')); + + toaster({ + title: t('toast.sentToUnifiedCanvas'), + status: 'success', + duration: 2500, + isClosable: true, + }); + }; + + const handleUseAllParameters = useCallback(() => { + recallAllParameters(image); + }, [image, recallAllParameters]); + + const handleLightBox = () => { + // dispatch(setCurrentImage(image)); + // dispatch(setIsLightboxOpen(true)); + }; + + const handleAddToBoard = useCallback(() => { + onClickAddToBoard(image); + }, [image, onClickAddToBoard]); + + const handleRemoveFromBoard = useCallback(() => { + if (!image.board_id) { + return; + } + removeFromBoard({ board_id: image.board_id, image_name: image.image_name }); + }, [image.board_id, image.image_name, removeFromBoard]); + + const handleOpenInNewTab = () => { + window.open(image.image_url, '_blank'); + }; + + const handleAddSelectionToBatch = useCallback(() => { + dispatch(selectionAddedToBatch()); + }, [dispatch]); + + const handleAddToBatch = useCallback(() => { + dispatch(imagesAddedToBatch([image.image_name])); + }, [dispatch, image.image_name]); + + return ( + + menuProps={{ size: 'sm', isLazy: true }} + renderMenu={() => ( + + {selectionCount === 1 ? ( + <> + } + onClickCapture={handleOpenInNewTab} + > + {t('common.openInNewTab')} + + {isLightboxEnabled && ( + } onClickCapture={handleLightBox}> + {t('parameters.openInViewer')} + + )} + } + onClickCapture={handleRecallPrompt} + isDisabled={ + image?.metadata?.positive_conditioning === undefined + } + > + {t('parameters.usePrompt')} + + + } + onClickCapture={handleRecallSeed} + isDisabled={image?.metadata?.seed === undefined} + > + {t('parameters.useSeed')} + + {/* } + onClickCapture={handleRecallInitialImage} + isDisabled={image?.metadata?.type !== 'img2img'} + > + {t('parameters.useInitImg')} + */} + } + onClickCapture={handleUseAllParameters} + isDisabled={ + // what should these be + !['t2l', 'l2l', 'inpaint'].includes( + String(image?.metadata?.type) + ) + } + > + {t('parameters.useAll')} + + } + onClickCapture={handleSendToImageToImage} + id="send-to-img2img" + > + {t('parameters.sendToImg2Img')} + + {isCanvasEnabled && ( + } + onClickCapture={handleSendToCanvas} + id="send-to-canvas" + > + {t('parameters.sendToUnifiedCanvas')} + + )} + {/* } + isDisabled={isInBatch} + onClickCapture={handleAddToBatch} + > + Add to Batch + */} + } onClickCapture={handleAddToBoard}> + {image.board_id ? 'Change Board' : 'Add to Board'} + + {image.board_id && ( + } + onClickCapture={handleRemoveFromBoard} + > + Remove from Board + + )} + } + onClickCapture={handleDelete} + > + {t('gallery.deleteImage')} + + + ) : ( + <> + } + onClickCapture={handleAddToBoard} + > + Move Selection to Board + + {/* } + onClickCapture={handleAddSelectionToBatch} + > + Add Selection to Batch + */} + } + onClickCapture={handleDelete} + > + Delete Selection + + + )} + + )} + > + {children} + + ); +}; + +export default memo(ImageContextMenu); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx index a22eb6d20f..a5fc653913 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx @@ -5,7 +5,7 @@ import { Flex, FlexProps, Grid, - Icon, + Skeleton, Text, VStack, forwardRef, @@ -18,12 +18,8 @@ import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox'; import IAIIconButton from 'common/components/IAIIconButton'; import IAIPopover from 'common/components/IAIPopover'; import IAISlider from 'common/components/IAISlider'; -import { gallerySelector } from 'features/gallery/store/gallerySelectors'; import { setGalleryImageMinimumWidth, - setGalleryImageObjectFit, - setShouldAutoSwitchToNewImages, - setShouldUseSingleGalleryColumn, setGalleryView, } from 'features/gallery/store/gallerySlice'; import { togglePinGalleryPanel } from 'features/ui/store/uiSlice'; @@ -42,77 +38,56 @@ import { import { useTranslation } from 'react-i18next'; import { BsPinAngle, BsPinAngleFill } from 'react-icons/bs'; import { FaImage, FaServer, FaWrench } from 'react-icons/fa'; -import { MdPhotoLibrary } from 'react-icons/md'; -import HoverableImage from './HoverableImage'; +import GalleryImage from './GalleryImage'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { createSelector } from '@reduxjs/toolkit'; -import { RootState } from 'app/store/store'; -import { Virtuoso, VirtuosoGrid } from 'react-virtuoso'; +import { RootState, stateSelector } from 'app/store/store'; +import { VirtuosoGrid } from 'react-virtuoso'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { uiSelector } from 'features/ui/store/uiSelectors'; import { ASSETS_CATEGORIES, IMAGE_CATEGORIES, imageCategoriesChanged, - selectImagesAll, -} from '../store/imagesSlice'; + shouldAutoSwitchChanged, + selectFilteredImages, +} from 'features/gallery/store/gallerySlice'; import { receivedPageOfImages } from 'services/api/thunks/image'; import BoardsList from './Boards/BoardsList'; -import { boardsSelector } from '../store/boardSlice'; import { ChevronUpIcon } from '@chakra-ui/icons'; import { useListAllBoardsQuery } from 'services/api/endpoints/boards'; import { mode } from 'theme/util/mode'; +import { ImageDTO } from 'services/api/types'; +import { IAINoContentFallback } from 'common/components/IAIImageFallback'; -const itemSelector = createSelector( - [(state: RootState) => state], - (state) => { - const { categories, total: allImagesTotal, isLoading } = state.images; - const { selectedBoardId } = state.boards; +const LOADING_IMAGE_ARRAY = Array(20).fill('loading'); - const allImages = selectImagesAll(state); +const selector = createSelector( + [stateSelector, selectFilteredImages], + (state, filteredImages) => { + const { + categories, + total: allImagesTotal, + isLoading, + selectedBoardId, + galleryImageMinimumWidth, + galleryView, + shouldAutoSwitch, + } = state.gallery; + const { shouldPinGallery } = state.ui; - const images = allImages.filter((i) => { - const isInCategory = categories.includes(i.image_category); - const isInSelectedBoard = selectedBoardId - ? i.board_id === selectedBoardId - : true; - return isInCategory && isInSelectedBoard; - }); + const images = filteredImages as (ImageDTO | string)[]; return { - images, + images: isLoading ? images.concat(LOADING_IMAGE_ARRAY) : images, allImagesTotal, isLoading, categories, selectedBoardId, - }; - }, - defaultSelectorOptions -); - -const mainSelector = createSelector( - [gallerySelector, uiSelector, boardsSelector], - (gallery, ui, boards) => { - const { - galleryImageMinimumWidth, - galleryImageObjectFit, - shouldAutoSwitchToNewImages, - shouldUseSingleGalleryColumn, - selectedImage, - galleryView, - } = gallery; - - const { shouldPinGallery } = ui; - return { shouldPinGallery, galleryImageMinimumWidth, - galleryImageObjectFit, - shouldAutoSwitchToNewImages, - shouldUseSingleGalleryColumn, - selectedImage, + shouldAutoSwitch, galleryView, - selectedBoardId: boards.selectedBoardId, }; }, defaultSelectorOptions @@ -140,17 +115,16 @@ const ImageGalleryContent = () => { const { colorMode } = useColorMode(); const { + images, + isLoading, + allImagesTotal, + categories, + selectedBoardId, shouldPinGallery, galleryImageMinimumWidth, - galleryImageObjectFit, - shouldAutoSwitchToNewImages, - shouldUseSingleGalleryColumn, - selectedImage, + shouldAutoSwitch, galleryView, - } = useAppSelector(mainSelector); - - const { images, isLoading, allImagesTotal, categories, selectedBoardId } = - useAppSelector(itemSelector); + } = useAppSelector(selector); const { selectedBoard } = useListAllBoardsQuery(undefined, { selectFromResult: ({ data }) => ({ @@ -208,11 +182,14 @@ const ImageGalleryContent = () => { return () => osInstance()?.destroy(); }, [scroller, initialize, osInstance]); - const setScrollerRef = useCallback((ref: HTMLElement | Window | null) => { - if (ref instanceof HTMLElement) { - setScroller(ref); - } - }, []); + useEffect(() => { + dispatch( + receivedPageOfImages({ + categories: ['general'], + is_intermediate: false, + }) + ); + }, [dispatch]); const handleClickImagesCategory = useCallback(() => { dispatch(imageCategoriesChanged(IMAGE_CATEGORIES)); @@ -314,29 +291,11 @@ const ImageGalleryContent = () => { withReset handleReset={() => dispatch(setGalleryImageMinimumWidth(64))} /> - - dispatch( - setGalleryImageObjectFit( - galleryImageObjectFit === 'contain' ? 'cover' : 'contain' - ) - ) - } - /> ) => - dispatch(setShouldAutoSwitchToNewImages(e.target.checked)) - } - /> - ) => - dispatch(setShouldUseSingleGalleryColumn(e.target.checked)) + dispatch(shouldAutoSwitchChanged(e.target.checked)) } /> @@ -358,41 +317,28 @@ const ImageGalleryContent = () => { {images.length || areMoreAvailable ? ( <> - {shouldUseSingleGalleryColumn ? ( - setScrollerRef(ref)} - itemContent={(index, item) => ( - - - - )} - /> - ) : ( - ( - + typeof item === 'string' ? ( + - )} - /> - )} + ) : ( + + ) + } + /> { ) : ( - - - {t('gallery.noImagesInGallery')} - + )} @@ -436,7 +365,7 @@ const ImageGalleryContent = () => { type ItemContainerProps = PropsWithChildren & FlexProps; const ItemContainer = forwardRef((props: ItemContainerProps, ref) => ( - + {props.children} )); @@ -453,8 +382,7 @@ const ListContainer = forwardRef((props: ListContainerProps, ref) => { className="list-container" ref={ref} sx={{ - gap: 2, - gridTemplateColumns: `repeat(auto-fit, minmax(${galleryImageMinimumWidth}px, 1fr));`, + gridTemplateColumns: `repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, 1fr));`, }} > {props.children} diff --git a/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx b/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx index b1f06ad433..69dc1b2b19 100644 --- a/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx @@ -5,14 +5,13 @@ import { clamp, isEqual } from 'lodash-es'; import { useCallback, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { FaAngleLeft, FaAngleRight } from 'react-icons/fa'; -import { gallerySelector } from '../store/gallerySelectors'; -import { RootState } from 'app/store/store'; -import { imageSelected } from '../store/gallerySlice'; -import { useHotkeys } from 'react-hotkeys-hook'; +import { stateSelector } from 'app/store/store'; import { - selectFilteredImagesAsObject, - selectFilteredImagesIds, -} from '../store/imagesSlice'; + imageSelected, + selectImagesById, +} from 'features/gallery/store/gallerySlice'; +import { useHotkeys } from 'react-hotkeys-hook'; +import { selectFilteredImages } from 'features/gallery/store/gallerySlice'; const nextPrevButtonTriggerAreaStyles: ChakraProps['sx'] = { height: '100%', @@ -25,45 +24,40 @@ const nextPrevButtonStyles: ChakraProps['sx'] = { }; export const nextPrevImageButtonsSelector = createSelector( - [ - (state: RootState) => state, - gallerySelector, - selectFilteredImagesAsObject, - selectFilteredImagesIds, - ], - (state, gallery, filteredImagesAsObject, filteredImageIds) => { - const { selectedImage } = gallery; + [stateSelector, selectFilteredImages], + (state, filteredImages) => { + const lastSelectedImage = + state.gallery.selection[state.gallery.selection.length - 1]; - if (!selectedImage) { + if (!lastSelectedImage || filteredImages.length === 0) { return { isOnFirstImage: true, isOnLastImage: true, }; } - const currentImageIndex = filteredImageIds.findIndex( - (i) => i === selectedImage + const currentImageIndex = filteredImages.findIndex( + (i) => i.image_name === lastSelectedImage ); - const nextImageIndex = clamp( currentImageIndex + 1, 0, - filteredImageIds.length - 1 + filteredImages.length - 1 ); const prevImageIndex = clamp( currentImageIndex - 1, 0, - filteredImageIds.length - 1 + filteredImages.length - 1 ); - const nextImageId = filteredImageIds[nextImageIndex]; - const prevImageId = filteredImageIds[prevImageIndex]; + const nextImageId = filteredImages[nextImageIndex].image_name; + const prevImageId = filteredImages[prevImageIndex].image_name; - const nextImage = filteredImagesAsObject[nextImageId]; - const prevImage = filteredImagesAsObject[prevImageId]; + const nextImage = selectImagesById(state, nextImageId); + const prevImage = selectImagesById(state, prevImageId); - const imagesLength = filteredImageIds.length; + const imagesLength = filteredImages.length; return { isOnFirstImage: currentImageIndex === 0, @@ -101,11 +95,11 @@ const NextPrevImageButtons = () => { }, []); const handlePrevImage = useCallback(() => { - dispatch(imageSelected(prevImageId)); + prevImageId && dispatch(imageSelected(prevImageId)); }, [dispatch, prevImageId]); const handleNextImage = useCallback(() => { - dispatch(imageSelected(nextImageId)); + nextImageId && dispatch(imageSelected(nextImageId)); }, [dispatch, nextImageId]); useHotkeys( diff --git a/invokeai/frontend/web/src/features/gallery/components/SelectedItemOverlay.tsx b/invokeai/frontend/web/src/features/gallery/components/SelectedItemOverlay.tsx deleted file mode 100644 index 3fabe706d6..0000000000 --- a/invokeai/frontend/web/src/features/gallery/components/SelectedItemOverlay.tsx +++ /dev/null @@ -1,40 +0,0 @@ -import { useColorMode, useToken } from '@chakra-ui/react'; -import { motion } from 'framer-motion'; -import { mode } from 'theme/util/mode'; - -export const SelectedItemOverlay = () => { - const [accent400, accent500] = useToken('colors', [ - 'accent.400', - 'accent.500', - ]); - - const { colorMode } = useColorMode(); - - return ( - - ); -}; diff --git a/invokeai/frontend/web/src/features/gallery/hooks/useGetImageByName.ts b/invokeai/frontend/web/src/features/gallery/hooks/useGetImageByName.ts deleted file mode 100644 index 89709b322a..0000000000 --- a/invokeai/frontend/web/src/features/gallery/hooks/useGetImageByName.ts +++ /dev/null @@ -1,18 +0,0 @@ -import { useAppSelector } from 'app/store/storeHooks'; -import { selectImagesEntities } from '../store/imagesSlice'; -import { useCallback } from 'react'; - -const useGetImageByName = () => { - const images = useAppSelector(selectImagesEntities); - return useCallback( - (name: string | undefined) => { - if (!name) { - return; - } - return images[name]; - }, - [images] - ); -}; - -export default useGetImageByName; diff --git a/invokeai/frontend/web/src/features/gallery/store/actions.ts b/invokeai/frontend/web/src/features/gallery/store/actions.ts index 4234778120..0e1b1ef2a0 100644 --- a/invokeai/frontend/web/src/features/gallery/store/actions.ts +++ b/invokeai/frontend/web/src/features/gallery/store/actions.ts @@ -1,15 +1,6 @@ import { createAction } from '@reduxjs/toolkit'; -import { ImageUsage } from 'app/contexts/DeleteImageContext'; -import { ImageDTO, BoardDTO } from 'services/api/types'; - -export type RequestedImageDeletionArg = { - image: ImageDTO; - imageUsage: ImageUsage; -}; - -export const requestedImageDeletion = createAction( - 'gallery/requestedImageDeletion' -); +import { ImageUsage } from 'app/contexts/AddImageToBoardContext'; +import { BoardDTO } from 'services/api/types'; export type RequestedBoardImagesDeletionArg = { board: BoardDTO; diff --git a/invokeai/frontend/web/src/features/gallery/store/boardSlice.ts b/invokeai/frontend/web/src/features/gallery/store/boardSlice.ts index 7ec74dc4bf..e6b59eee9a 100644 --- a/invokeai/frontend/web/src/features/gallery/store/boardSlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/boardSlice.ts @@ -1,10 +1,8 @@ import { PayloadAction, createSlice } from '@reduxjs/toolkit'; import { RootState } from 'app/store/store'; -import { boardsApi } from 'services/api/endpoints/boards'; type BoardsState = { searchText: string; - selectedBoardId?: string; updateBoardModalOpen: boolean; }; @@ -17,9 +15,6 @@ const boardsSlice = createSlice({ name: 'boards', initialState: initialBoardsState, reducers: { - boardIdSelected: (state, action: PayloadAction) => { - state.selectedBoardId = action.payload; - }, setBoardSearchText: (state, action: PayloadAction) => { state.searchText = action.payload; }, @@ -27,19 +22,9 @@ const boardsSlice = createSlice({ state.updateBoardModalOpen = action.payload; }, }, - extraReducers: (builder) => { - builder.addMatcher( - boardsApi.endpoints.deleteBoard.matchFulfilled, - (state, action) => { - if (action.meta.arg.originalArgs === state.selectedBoardId) { - state.selectedBoardId = undefined; - } - } - ); - }, }); -export const { boardIdSelected, setBoardSearchText, setUpdateBoardModalOpen } = +export const { setBoardSearchText, setUpdateBoardModalOpen } = boardsSlice.actions; export const boardsSelector = (state: RootState) => state.boards; diff --git a/invokeai/frontend/web/src/features/gallery/store/galleryPersistDenylist.ts b/invokeai/frontend/web/src/features/gallery/store/galleryPersistDenylist.ts index 44e03f9f71..201cffa70e 100644 --- a/invokeai/frontend/web/src/features/gallery/store/galleryPersistDenylist.ts +++ b/invokeai/frontend/web/src/features/gallery/store/galleryPersistDenylist.ts @@ -1,8 +1,15 @@ -import { GalleryState } from './gallerySlice'; +import { initialGalleryState } from './gallerySlice'; /** * Gallery slice persist denylist */ -export const galleryPersistDenylist: (keyof GalleryState)[] = [ - 'shouldAutoSwitchToNewImages', +export const galleryPersistDenylist: (keyof typeof initialGalleryState)[] = [ + 'selection', + 'entities', + 'ids', + 'isLoading', + 'limit', + 'offset', + 'selectedBoardId', + 'total', ]; diff --git a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts index b7fc0809a6..41a52e3452 100644 --- a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts @@ -1,87 +1,260 @@ -import type { PayloadAction } from '@reduxjs/toolkit'; -import { createSlice } from '@reduxjs/toolkit'; -import { imageUpserted } from './imagesSlice'; +import type { PayloadAction, Update } from '@reduxjs/toolkit'; +import { + createEntityAdapter, + createSelector, + createSlice, +} from '@reduxjs/toolkit'; +import { RootState } from 'app/store/store'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import { dateComparator } from 'common/util/dateComparator'; +import { keyBy, uniq } from 'lodash-es'; +import { boardsApi } from 'services/api/endpoints/boards'; +import { + imageUrlsReceived, + receivedPageOfImages, +} from 'services/api/thunks/image'; +import { ImageCategory, ImageDTO } from 'services/api/types'; -type GalleryImageObjectFitType = 'contain' | 'cover'; +export const imagesAdapter = createEntityAdapter({ + selectId: (image) => image.image_name, + sortComparer: (a, b) => dateComparator(b.updated_at, a.updated_at), +}); -export interface GalleryState { - selectedImage?: string; +export const IMAGE_CATEGORIES: ImageCategory[] = ['general']; +export const ASSETS_CATEGORIES: ImageCategory[] = [ + 'control', + 'mask', + 'user', + 'other', +]; + +type AdditionaGalleryState = { + offset: number; + limit: number; + total: number; + isLoading: boolean; + categories: ImageCategory[]; + selectedBoardId?: string; + selection: string[]; + shouldAutoSwitch: boolean; galleryImageMinimumWidth: number; - galleryImageObjectFit: GalleryImageObjectFitType; - shouldAutoSwitchToNewImages: boolean; - shouldUseSingleGalleryColumn: boolean; galleryView: 'images' | 'assets' | 'boards'; -} - -export const initialGalleryState: GalleryState = { - galleryImageMinimumWidth: 64, - galleryImageObjectFit: 'cover', - shouldAutoSwitchToNewImages: true, - shouldUseSingleGalleryColumn: false, - galleryView: 'images', }; +export const initialGalleryState = + imagesAdapter.getInitialState({ + offset: 0, + limit: 0, + total: 0, + isLoading: true, + categories: IMAGE_CATEGORIES, + selection: [], + shouldAutoSwitch: true, + galleryImageMinimumWidth: 64, + galleryView: 'images', + }); + export const gallerySlice = createSlice({ name: 'gallery', initialState: initialGalleryState, reducers: { - imageSelected: (state, action: PayloadAction) => { - state.selectedImage = action.payload; - // TODO: if the user selects an image, disable the auto switch? - // state.shouldAutoSwitchToNewImages = false; + imageUpserted: (state, action: PayloadAction) => { + imagesAdapter.upsertOne(state, action.payload); + if ( + state.shouldAutoSwitch && + action.payload.image_category === 'general' + ) { + state.selection = [action.payload.image_name]; + } + }, + imageUpdatedOne: (state, action: PayloadAction>) => { + imagesAdapter.updateOne(state, action.payload); + }, + imageRemoved: (state, action: PayloadAction) => { + imagesAdapter.removeOne(state, action.payload); + }, + imagesRemoved: (state, action: PayloadAction) => { + imagesAdapter.removeMany(state, action.payload); + }, + imageCategoriesChanged: (state, action: PayloadAction) => { + state.categories = action.payload; + }, + imageRangeEndSelected: (state, action: PayloadAction) => { + const rangeEndImageName = action.payload; + const lastSelectedImage = state.selection[state.selection.length - 1]; + + const filteredImages = selectFilteredImagesLocal(state); + + const lastClickedIndex = filteredImages.findIndex( + (n) => n.image_name === lastSelectedImage + ); + + const currentClickedIndex = filteredImages.findIndex( + (n) => n.image_name === rangeEndImageName + ); + + if (lastClickedIndex > -1 && currentClickedIndex > -1) { + // We have a valid range! + const start = Math.min(lastClickedIndex, currentClickedIndex); + const end = Math.max(lastClickedIndex, currentClickedIndex); + + const imagesToSelect = filteredImages + .slice(start, end + 1) + .map((i) => i.image_name); + + state.selection = uniq(state.selection.concat(imagesToSelect)); + } + }, + imageSelectionToggled: (state, action: PayloadAction) => { + if ( + state.selection.includes(action.payload) && + state.selection.length > 1 + ) { + state.selection = state.selection.filter( + (imageName) => imageName !== action.payload + ); + } else { + state.selection = uniq(state.selection.concat(action.payload)); + } + }, + imageSelected: (state, action: PayloadAction) => { + state.selection = action.payload + ? [action.payload] + : [String(state.ids[0])]; + }, + shouldAutoSwitchChanged: (state, action: PayloadAction) => { + state.shouldAutoSwitch = action.payload; }, setGalleryImageMinimumWidth: (state, action: PayloadAction) => { state.galleryImageMinimumWidth = action.payload; }, - setGalleryImageObjectFit: ( - state, - action: PayloadAction - ) => { - state.galleryImageObjectFit = action.payload; - }, - setShouldAutoSwitchToNewImages: (state, action: PayloadAction) => { - state.shouldAutoSwitchToNewImages = action.payload; - }, - setShouldUseSingleGalleryColumn: ( - state, - action: PayloadAction - ) => { - state.shouldUseSingleGalleryColumn = action.payload; - }, setGalleryView: ( state, action: PayloadAction<'images' | 'assets' | 'boards'> ) => { state.galleryView = action.payload; }, + boardIdSelected: (state, action: PayloadAction) => { + state.selectedBoardId = action.payload; + }, }, extraReducers: (builder) => { - builder.addCase(imageUpserted, (state, action) => { - if ( - state.shouldAutoSwitchToNewImages && - action.payload.image_category === 'general' - ) { - state.selectedImage = action.payload.image_name; - } + builder.addCase(receivedPageOfImages.pending, (state) => { + state.isLoading = true; }); - // builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { - // const { image_name, image_url, thumbnail_url } = action.payload; + builder.addCase(receivedPageOfImages.rejected, (state) => { + state.isLoading = false; + }); + builder.addCase(receivedPageOfImages.fulfilled, (state, action) => { + state.isLoading = false; + const { board_id, categories, image_origin, is_intermediate } = + action.meta.arg; - // if (state.selectedImage?.image_name === image_name) { - // state.selectedImage.image_url = image_url; - // state.selectedImage.thumbnail_url = thumbnail_url; - // } - // }); + const { items, offset, limit, total } = action.payload; + + const transformedItems = items.map((item) => ({ + ...item, + isSelected: false, + })); + + imagesAdapter.upsertMany(state, transformedItems); + + if (state.selection.length === 0) { + state.selection = [items[0].image_name]; + } + + if (!categories?.includes('general') || board_id) { + // need to skip updating the total images count if the images recieved were for a specific board + // TODO: this doesn't work when on the Asset tab/category... + return; + } + + state.offset = offset; + state.limit = limit; + state.total = total; + }); + builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { + const { image_name, image_url, thumbnail_url } = action.payload; + + imagesAdapter.updateOne(state, { + id: image_name, + changes: { image_url, thumbnail_url }, + }); + }); + builder.addMatcher( + boardsApi.endpoints.deleteBoard.matchFulfilled, + (state, action) => { + if (action.meta.arg.originalArgs === state.selectedBoardId) { + state.selectedBoardId = undefined; + } + } + ); }, }); export const { + selectAll: selectImagesAll, + selectById: selectImagesById, + selectEntities: selectImagesEntities, + selectIds: selectImagesIds, + selectTotal: selectImagesTotal, +} = imagesAdapter.getSelectors((state) => state.gallery); + +export const { + imageUpserted, + imageUpdatedOne, + imageRemoved, + imagesRemoved, + imageCategoriesChanged, + imageRangeEndSelected, + imageSelectionToggled, imageSelected, + shouldAutoSwitchChanged, setGalleryImageMinimumWidth, - setGalleryImageObjectFit, - setShouldAutoSwitchToNewImages, - setShouldUseSingleGalleryColumn, setGalleryView, + boardIdSelected, } = gallerySlice.actions; export default gallerySlice.reducer; + +export const selectFilteredImagesLocal = createSelector( + (state: typeof initialGalleryState) => state, + (galleryState) => { + const allImages = imagesAdapter.getSelectors().selectAll(galleryState); + const { categories, selectedBoardId } = galleryState; + + const filteredImages = allImages.filter((i) => { + const isInCategory = categories.includes(i.image_category); + const isInSelectedBoard = selectedBoardId + ? i.board_id === selectedBoardId + : true; + return isInCategory && isInSelectedBoard; + }); + + return filteredImages; + } +); + +export const selectFilteredImages = createSelector( + (state: RootState) => state, + (state) => { + return selectFilteredImagesLocal(state.gallery); + }, + defaultSelectorOptions +); + +export const selectFilteredImagesAsObject = createSelector( + selectFilteredImages, + (filteredImages) => keyBy(filteredImages, 'image_name') +); + +export const selectFilteredImagesIds = createSelector( + selectFilteredImages, + (filteredImages) => filteredImages.map((i) => i.image_name) +); + +export const selectLastSelectedImage = createSelector( + (state: RootState) => state, + (state) => state.gallery.selection[state.gallery.selection.length - 1], + defaultSelectorOptions +); diff --git a/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts b/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts deleted file mode 100644 index 8041ffd5c5..0000000000 --- a/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts +++ /dev/null @@ -1,182 +0,0 @@ -import { - PayloadAction, - Update, - createEntityAdapter, - createSelector, - createSlice, -} from '@reduxjs/toolkit'; -import { RootState } from 'app/store/store'; -import { ImageCategory, ImageDTO } from 'services/api/types'; -import { dateComparator } from 'common/util/dateComparator'; -import { keyBy } from 'lodash-es'; -import { - imageDeleted, - imageUrlsReceived, - receivedPageOfImages, -} from 'services/api/thunks/image'; - -export const imagesAdapter = createEntityAdapter({ - selectId: (image) => image.image_name, - sortComparer: (a, b) => dateComparator(b.updated_at, a.updated_at), -}); - -export const IMAGE_CATEGORIES: ImageCategory[] = ['general']; -export const ASSETS_CATEGORIES: ImageCategory[] = [ - 'control', - 'mask', - 'user', - 'other', -]; - -type AdditionaImagesState = { - offset: number; - limit: number; - total: number; - isLoading: boolean; - categories: ImageCategory[]; -}; - -export const initialImagesState = - imagesAdapter.getInitialState({ - offset: 0, - limit: 0, - total: 0, - isLoading: false, - categories: IMAGE_CATEGORIES, - }); - -export type ImagesState = typeof initialImagesState; - -const imagesSlice = createSlice({ - name: 'images', - initialState: initialImagesState, - reducers: { - imageUpserted: (state, action: PayloadAction) => { - imagesAdapter.upsertOne(state, action.payload); - }, - imageUpdatedOne: (state, action: PayloadAction>) => { - imagesAdapter.updateOne(state, action.payload); - }, - imageRemoved: (state, action: PayloadAction) => { - imagesAdapter.removeOne(state, action.payload); - }, - imagesRemoved: (state, action: PayloadAction) => { - imagesAdapter.removeMany(state, action.payload); - }, - imageCategoriesChanged: (state, action: PayloadAction) => { - state.categories = action.payload; - }, - }, - extraReducers: (builder) => { - builder.addCase(receivedPageOfImages.pending, (state) => { - state.isLoading = true; - }); - builder.addCase(receivedPageOfImages.rejected, (state) => { - state.isLoading = false; - }); - builder.addCase(receivedPageOfImages.fulfilled, (state, action) => { - state.isLoading = false; - const { board_id, categories, image_origin, is_intermediate } = - action.meta.arg; - - const { items, offset, limit, total } = action.payload; - imagesAdapter.upsertMany(state, items); - - if (!categories?.includes('general') || board_id) { - // need to skip updating the total images count if the images recieved were for a specific board - // TODO: this doesn't work when on the Asset tab/category... - return; - } - - state.offset = offset; - state.limit = limit; - state.total = total; - }); - builder.addCase(imageDeleted.pending, (state, action) => { - // Image deleted - const { image_name } = action.meta.arg; - imagesAdapter.removeOne(state, image_name); - }); - builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { - const { image_name, image_url, thumbnail_url } = action.payload; - - imagesAdapter.updateOne(state, { - id: image_name, - changes: { image_url, thumbnail_url }, - }); - }); - }, -}); - -export const { - selectAll: selectImagesAll, - selectById: selectImagesById, - selectEntities: selectImagesEntities, - selectIds: selectImagesIds, - selectTotal: selectImagesTotal, -} = imagesAdapter.getSelectors((state) => state.images); - -export const { - imageUpserted, - imageUpdatedOne, - imageRemoved, - imagesRemoved, - imageCategoriesChanged, -} = imagesSlice.actions; - -export default imagesSlice.reducer; - -export const selectFilteredImagesAsArray = createSelector( - (state: RootState) => state, - (state) => { - const { - images: { categories }, - } = state; - - return selectImagesAll(state).filter((i) => - categories.includes(i.image_category) - ); - } -); - -export const selectFilteredImagesAsObject = createSelector( - (state: RootState) => state, - (state) => { - const { - images: { categories }, - } = state; - - return keyBy( - selectImagesAll(state).filter((i) => - categories.includes(i.image_category) - ), - 'image_name' - ); - } -); - -export const selectFilteredImagesIds = createSelector( - (state: RootState) => state, - (state) => { - const { - images: { categories }, - } = state; - - return selectImagesAll(state) - .filter((i) => categories.includes(i.image_category)) - .map((i) => i.image_name); - } -); - -// export const selectImageById = createSelector( -// (state: RootState, imageId) => state, -// (state) => { -// const { -// images: { categories }, -// } = state; - -// return selectImagesAll(state) -// .filter((i) => categories.includes(i.image_category)) -// .map((i) => i.image_name); -// } -// ); diff --git a/invokeai/frontend/web/src/features/imageDeletion/components/DeleteImageButton.tsx b/invokeai/frontend/web/src/features/imageDeletion/components/DeleteImageButton.tsx new file mode 100644 index 0000000000..dde6d1a517 --- /dev/null +++ b/invokeai/frontend/web/src/features/imageDeletion/components/DeleteImageButton.tsx @@ -0,0 +1,37 @@ +import { IconButtonProps } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import IAIIconButton from 'common/components/IAIIconButton'; +import { useTranslation } from 'react-i18next'; +import { FaTrash } from 'react-icons/fa'; + +const deleteImageButtonsSelector = createSelector( + [stateSelector], + ({ system }) => { + const { isProcessing, isConnected } = system; + + return isConnected && !isProcessing; + } +); + +type DeleteImageButtonProps = Omit & { + onClick: () => void; +}; + +export const DeleteImageButton = (props: DeleteImageButtonProps) => { + const { onClick, isDisabled } = props; + const { t } = useTranslation(); + const canDeleteImage = useAppSelector(deleteImageButtonsSelector); + + return ( + } + tooltip={`${t('gallery.deleteImage')} (Del)`} + aria-label={`${t('gallery.deleteImage')} (Del)`} + isDisabled={isDisabled || !canDeleteImage} + colorScheme="error" + /> + ); +}; diff --git a/invokeai/frontend/web/src/features/imageDeletion/components/DeleteImageModal.tsx b/invokeai/frontend/web/src/features/imageDeletion/components/DeleteImageModal.tsx new file mode 100644 index 0000000000..8306437cc7 --- /dev/null +++ b/invokeai/frontend/web/src/features/imageDeletion/components/DeleteImageModal.tsx @@ -0,0 +1,124 @@ +import { + AlertDialog, + AlertDialogBody, + AlertDialogContent, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogOverlay, + Divider, + Flex, + Text, +} from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIButton from 'common/components/IAIButton'; +import IAISwitch from 'common/components/IAISwitch'; +import { setShouldConfirmOnDelete } from 'features/system/store/systemSlice'; + +import { ChangeEvent, memo, useCallback, useRef } from 'react'; +import { useTranslation } from 'react-i18next'; +import ImageUsageMessage from './ImageUsageMessage'; +import { stateSelector } from 'app/store/store'; +import { + imageDeletionConfirmed, + imageToDeleteCleared, + isModalOpenChanged, + selectImageUsage, +} from '../store/imageDeletionSlice'; + +const selector = createSelector( + [stateSelector, selectImageUsage], + ({ system, config, imageDeletion }, imageUsage) => { + const { shouldConfirmOnDelete } = system; + const { canRestoreDeletedImagesFromBin } = config; + const { imageToDelete, isModalOpen } = imageDeletion; + return { + shouldConfirmOnDelete, + canRestoreDeletedImagesFromBin, + imageToDelete, + imageUsage, + isModalOpen, + }; + }, + defaultSelectorOptions +); + +const DeleteImageModal = () => { + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + + const { + shouldConfirmOnDelete, + canRestoreDeletedImagesFromBin, + imageToDelete, + imageUsage, + isModalOpen, + } = useAppSelector(selector); + + const handleChangeShouldConfirmOnDelete = useCallback( + (e: ChangeEvent) => + dispatch(setShouldConfirmOnDelete(!e.target.checked)), + [dispatch] + ); + + const handleClose = useCallback(() => { + dispatch(imageToDeleteCleared()); + dispatch(isModalOpenChanged(false)); + }, [dispatch]); + + const handleDelete = useCallback(() => { + if (!imageToDelete || !imageUsage) { + return; + } + dispatch(imageToDeleteCleared()); + dispatch(imageDeletionConfirmed({ imageDTO: imageToDelete, imageUsage })); + }, [dispatch, imageToDelete, imageUsage]); + + const cancelRef = useRef(null); + + return ( + + + + + {t('gallery.deleteImage')} + + + + + + + + {canRestoreDeletedImagesFromBin + ? t('gallery.deleteImageBin') + : t('gallery.deleteImagePermanent')} + + {t('common.areYouSure')} + + + + + + Cancel + + + Delete + + + + + + ); +}; + +export default memo(DeleteImageModal); diff --git a/invokeai/frontend/web/src/features/imageDeletion/components/ImageUsageMessage.tsx b/invokeai/frontend/web/src/features/imageDeletion/components/ImageUsageMessage.tsx new file mode 100644 index 0000000000..9bd4ca5198 --- /dev/null +++ b/invokeai/frontend/web/src/features/imageDeletion/components/ImageUsageMessage.tsx @@ -0,0 +1,33 @@ +import { some } from 'lodash-es'; +import { memo } from 'react'; +import { ImageUsage } from '../store/imageDeletionSlice'; +import { ListItem, Text, UnorderedList } from '@chakra-ui/react'; + +const ImageUsageMessage = (props: { imageUsage?: ImageUsage }) => { + const { imageUsage } = props; + + if (!imageUsage) { + return null; + } + + if (!some(imageUsage)) { + return null; + } + + return ( + <> + This image is currently in use in the following features: + + {imageUsage.isInitialImage && Image to Image} + {imageUsage.isCanvasImage && Unified Canvas} + {imageUsage.isControlNetImage && ControlNet} + {imageUsage.isNodesImage && Node Editor} + + + If you delete this image, those features will immediately be reset. + + + ); +}; + +export default memo(ImageUsageMessage); diff --git a/invokeai/frontend/web/src/features/imageDeletion/store/imageDeletionSlice.ts b/invokeai/frontend/web/src/features/imageDeletion/store/imageDeletionSlice.ts new file mode 100644 index 0000000000..49630bcdb4 --- /dev/null +++ b/invokeai/frontend/web/src/features/imageDeletion/store/imageDeletionSlice.ts @@ -0,0 +1,100 @@ +import { + PayloadAction, + createAction, + createSelector, + createSlice, +} from '@reduxjs/toolkit'; +import { RootState } from 'app/store/store'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import { some } from 'lodash-es'; +import { ImageDTO } from 'services/api/types'; + +type DeleteImageState = { + imageToDelete: ImageDTO | null; + isModalOpen: boolean; +}; + +export const initialDeleteImageState: DeleteImageState = { + imageToDelete: null, + isModalOpen: false, +}; + +const imageDeletion = createSlice({ + name: 'imageDeletion', + initialState: initialDeleteImageState, + reducers: { + isModalOpenChanged: (state, action: PayloadAction) => { + state.isModalOpen = action.payload; + }, + imageToDeleteSelected: (state, action: PayloadAction) => { + state.imageToDelete = action.payload; + }, + imageToDeleteCleared: (state) => { + state.imageToDelete = null; + state.isModalOpen = false; + }, + }, +}); + +export const { + isModalOpenChanged, + imageToDeleteSelected, + imageToDeleteCleared, +} = imageDeletion.actions; + +export default imageDeletion.reducer; + +export type ImageUsage = { + isInitialImage: boolean; + isCanvasImage: boolean; + isNodesImage: boolean; + isControlNetImage: boolean; +}; + +export const selectImageUsage = createSelector( + [(state: RootState) => state], + ({ imageDeletion, generation, canvas, nodes, controlNet }) => { + const { imageToDelete } = imageDeletion; + + if (!imageToDelete) { + return; + } + + const { image_name } = imageToDelete; + + const isInitialImage = generation.initialImage?.imageName === image_name; + + const isCanvasImage = canvas.layerState.objects.some( + (obj) => obj.kind === 'image' && obj.imageName === image_name + ); + + const isNodesImage = nodes.nodes.some((node) => { + return some( + node.data.inputs, + (input) => + input.type === 'image' && input.value?.image_name === image_name + ); + }); + + const isControlNetImage = some( + controlNet.controlNets, + (c) => + c.controlImage === image_name || c.processedControlImage === image_name + ); + + const imageUsage: ImageUsage = { + isInitialImage, + isCanvasImage, + isNodesImage, + isControlNetImage, + }; + + return imageUsage; + }, + defaultSelectorOptions +); + +export const imageDeletionConfirmed = createAction<{ + imageDTO: ImageDTO; + imageUsage: ImageUsage; +}>('imageDeletion/imageDeletionConfirmed'); diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx new file mode 100644 index 0000000000..4ca9700a8c --- /dev/null +++ b/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx @@ -0,0 +1,64 @@ +import { Flex } from '@chakra-ui/react'; +import { useAppDispatch } from 'app/store/storeHooks'; +import IAIIconButton from 'common/components/IAIIconButton'; +import IAISlider from 'common/components/IAISlider'; +import { memo, useCallback } from 'react'; +import { FaTrash } from 'react-icons/fa'; +import { + Lora, + loraRemoved, + loraWeightChanged, + loraWeightReset, +} from '../store/loraSlice'; + +type Props = { + lora: Lora; +}; + +const ParamLora = (props: Props) => { + const dispatch = useAppDispatch(); + const { lora } = props; + + const handleChange = useCallback( + (v: number) => { + dispatch(loraWeightChanged({ id: lora.id, weight: v })); + }, + [dispatch, lora.id] + ); + + const handleReset = useCallback(() => { + dispatch(loraWeightReset(lora.id)); + }, [dispatch, lora.id]); + + const handleRemoveLora = useCallback(() => { + dispatch(loraRemoved(lora.id)); + }, [dispatch, lora.id]); + + return ( + + + } + colorScheme="error" + /> + + ); +}; + +export default memo(ParamLora); diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx new file mode 100644 index 0000000000..6e69f036df --- /dev/null +++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx @@ -0,0 +1,36 @@ +import { Flex } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAICollapse from 'common/components/IAICollapse'; +import { size } from 'lodash-es'; +import { memo } from 'react'; +import ParamLoraList from './ParamLoraList'; +import ParamLoraSelect from './ParamLoraSelect'; + +const selector = createSelector( + stateSelector, + (state) => { + const loraCount = size(state.lora.loras); + return { + activeLabel: loraCount > 0 ? `${loraCount} Active` : undefined, + }; + }, + defaultSelectorOptions +); + +const ParamLoraCollapse = () => { + const { activeLabel } = useAppSelector(selector); + + return ( + + + + + + + ); +}; + +export default memo(ParamLoraCollapse); diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx new file mode 100644 index 0000000000..89432ac862 --- /dev/null +++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx @@ -0,0 +1,24 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import { map } from 'lodash-es'; +import ParamLora from './ParamLora'; + +const selector = createSelector( + stateSelector, + ({ lora }) => { + const { loras } = lora; + + return { loras }; + }, + defaultSelectorOptions +); + +const ParamLoraList = () => { + const { loras } = useAppSelector(selector); + + return map(loras, (lora) => ); +}; + +export default ParamLoraList; diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx new file mode 100644 index 0000000000..9168814f35 --- /dev/null +++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx @@ -0,0 +1,117 @@ +import { Flex, Text } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; +import { forEach } from 'lodash-es'; +import { forwardRef, useCallback, useMemo } from 'react'; +import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; +import { loraAdded } from '../store/loraSlice'; + +type LoraSelectItem = { + label: string; + value: string; + description?: string; +}; + +const selector = createSelector( + stateSelector, + ({ lora }) => ({ + loras: lora.loras, + }), + defaultSelectorOptions +); + +const ParamLoraSelect = () => { + const dispatch = useAppDispatch(); + const { loras } = useAppSelector(selector); + const { data: lorasQueryData } = useGetLoRAModelsQuery(); + + const data = useMemo(() => { + if (!lorasQueryData) { + return []; + } + + const data: LoraSelectItem[] = []; + + forEach(lorasQueryData.entities, (lora, id) => { + if (!lora || Boolean(id in loras)) { + return; + } + + data.push({ + value: id, + label: lora.name, + description: lora.description, + }); + }); + + return data; + }, [loras, lorasQueryData]); + + const handleChange = useCallback( + (v: string[]) => { + const loraEntity = lorasQueryData?.entities[v[0]]; + if (!loraEntity) { + return; + } + v[0] && dispatch(loraAdded(loraEntity)); + }, + [dispatch, lorasQueryData?.entities] + ); + + if (lorasQueryData?.ids.length === 0) { + return ( + + + No LoRAs Loaded + + + ); + } + + return ( + + item.label.toLowerCase().includes(value.toLowerCase().trim()) || + item.value.toLowerCase().includes(value.toLowerCase().trim()) + } + onChange={handleChange} + /> + ); +}; + +interface ItemProps extends React.ComponentPropsWithoutRef<'div'> { + value: string; + label: string; + description?: string; +} + +const SelectItem = forwardRef( + ({ label, description, ...others }: ItemProps, ref) => { + return ( +
+
+ {label} + {description && ( + + {description} + + )} +
+
+ ); + } +); + +SelectItem.displayName = 'SelectItem'; + +export default ParamLoraSelect; diff --git a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts new file mode 100644 index 0000000000..7da6018e58 --- /dev/null +++ b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts @@ -0,0 +1,51 @@ +import { PayloadAction, createSlice } from '@reduxjs/toolkit'; +import { LoRAModelConfigEntity } from 'services/api/endpoints/models'; + +export type Lora = { + id: string; + name: string; + weight: number; +}; + +export const defaultLoRAConfig: Omit = { + weight: 0.75, +}; + +export type LoraState = { + loras: Record; +}; + +export const intialLoraState: LoraState = { + loras: {}, +}; + +export const loraSlice = createSlice({ + name: 'lora', + initialState: intialLoraState, + reducers: { + loraAdded: (state, action: PayloadAction) => { + const { name, id } = action.payload; + state.loras[id] = { id, name, ...defaultLoRAConfig }; + }, + loraRemoved: (state, action: PayloadAction) => { + const id = action.payload; + delete state.loras[id]; + }, + loraWeightChanged: ( + state, + action: PayloadAction<{ id: string; weight: number }> + ) => { + const { id, weight } = action.payload; + state.loras[id].weight = weight; + }, + loraWeightReset: (state, action: PayloadAction) => { + const id = action.payload; + state.loras[id].weight = defaultLoRAConfig.weight; + }, + }, +}); + +export const { loraAdded, loraRemoved, loraWeightChanged, loraWeightReset } = + loraSlice.actions; + +export default loraSlice.reducer; diff --git a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx index 65b7cfa560..9925a48381 100644 --- a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx @@ -3,19 +3,22 @@ import { memo } from 'react'; import { InputFieldTemplate, InputFieldValue } from '../types/types'; import ArrayInputFieldComponent from './fields/ArrayInputFieldComponent'; import BooleanInputFieldComponent from './fields/BooleanInputFieldComponent'; -import EnumInputFieldComponent from './fields/EnumInputFieldComponent'; -import ImageInputFieldComponent from './fields/ImageInputFieldComponent'; -import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent'; -import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent'; -import UNetInputFieldComponent from './fields/UNetInputFieldComponent'; import ClipInputFieldComponent from './fields/ClipInputFieldComponent'; -import VaeInputFieldComponent from './fields/VaeInputFieldComponent'; +import ColorInputFieldComponent from './fields/ColorInputFieldComponent'; +import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent'; import ControlInputFieldComponent from './fields/ControlInputFieldComponent'; +import EnumInputFieldComponent from './fields/EnumInputFieldComponent'; +import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent'; +import ImageInputFieldComponent from './fields/ImageInputFieldComponent'; +import ItemInputFieldComponent from './fields/ItemInputFieldComponent'; +import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent'; +import LoRAModelInputFieldComponent from './fields/LoRAModelInputFieldComponent'; import ModelInputFieldComponent from './fields/ModelInputFieldComponent'; import NumberInputFieldComponent from './fields/NumberInputFieldComponent'; import StringInputFieldComponent from './fields/StringInputFieldComponent'; -import ColorInputFieldComponent from './fields/ColorInputFieldComponent'; -import ItemInputFieldComponent from './fields/ItemInputFieldComponent'; +import UNetInputFieldComponent from './fields/UNetInputFieldComponent'; +import VaeInputFieldComponent from './fields/VaeInputFieldComponent'; +import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent'; type InputFieldComponentProps = { nodeId: string; @@ -151,6 +154,26 @@ const InputFieldComponent = (props: InputFieldComponentProps) => { ); } + if (type === 'vae_model' && template.type === 'vae_model') { + return ( + + ); + } + + if (type === 'lora_model' && template.type === 'lora_model') { + return ( + + ); + } + if (type === 'array' && template.type === 'array') { return ( { ); } + if (type === 'image_collection' && template.type === 'image_collection') { + return ( + + ); + } + return Unknown field type: {type}; }; diff --git a/invokeai/frontend/web/src/features/nodes/components/InvocationComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/InvocationComponent.tsx index fc3a6377b2..3c3568a6b2 100644 --- a/invokeai/frontend/web/src/features/nodes/components/InvocationComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/InvocationComponent.tsx @@ -30,7 +30,7 @@ const InvocationComponentWrapper = (props: InvocationComponentWrapperProps) => { position: 'relative', borderRadius: 'md', minWidth: NODE_MIN_WIDTH, - boxShadow: props.selected + shadow: props.selected ? `${nodeSelectedOutline}, ${nodeShadow}` : `${nodeShadow}`, }} diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ImageCollectionInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ImageCollectionInputFieldComponent.tsx new file mode 100644 index 0000000000..0ac1f7aa1c --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/fields/ImageCollectionInputFieldComponent.tsx @@ -0,0 +1,103 @@ +import { useAppDispatch } from 'app/store/storeHooks'; + +import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; +import { + ImageCollectionInputFieldTemplate, + ImageCollectionInputFieldValue, +} from 'features/nodes/types/types'; +import { memo, useCallback } from 'react'; + +import { FieldComponentProps } from './types'; +import IAIDndImage from 'common/components/IAIDndImage'; +import { ImageDTO } from 'services/api/types'; +import { Flex } from '@chakra-ui/react'; +import { useGetImageDTOQuery } from 'services/api/endpoints/images'; +import { skipToken } from '@reduxjs/toolkit/dist/query'; +import { uniq, uniqBy } from 'lodash-es'; +import { + NodesMultiImageDropData, + isValidDrop, + useDroppable, +} from 'app/components/ImageDnd/typesafeDnd'; +import IAIDropOverlay from 'common/components/IAIDropOverlay'; + +const ImageCollectionInputFieldComponent = ( + props: FieldComponentProps< + ImageCollectionInputFieldValue, + ImageCollectionInputFieldTemplate + > +) => { + const { nodeId, field } = props; + + const dispatch = useAppDispatch(); + + const handleDrop = useCallback( + ({ image_name }: ImageDTO) => { + dispatch( + fieldValueChanged({ + nodeId, + fieldName: field.name, + value: uniqBy([...(field.value ?? []), { image_name }], 'image_name'), + }) + ); + }, + [dispatch, field.name, field.value, nodeId] + ); + + const droppableData: NodesMultiImageDropData = { + id: `node-${nodeId}-${field.name}`, + actionType: 'SET_MULTI_NODES_IMAGE', + context: { nodeId, fieldName: field.name }, + }; + + const { + isOver, + setNodeRef: setDroppableRef, + active, + over, + } = useDroppable({ + id: `node_${nodeId}`, + data: droppableData, + }); + + const handleReset = useCallback(() => { + dispatch( + fieldValueChanged({ + nodeId, + fieldName: field.name, + value: undefined, + }) + ); + }, [dispatch, field.name, nodeId]); + + return ( + + {field.value?.map(({ image_name }) => ( + + ))} + {isValidDrop(droppableData, active) && } + + ); +}; + +export default memo(ImageCollectionInputFieldComponent); + +type ImageSubFieldProps = { imageName: string }; + +const ImageSubField = (props: ImageSubFieldProps) => { + const { currentData: image } = useGetImageDTOQuery(props.imageName); + + return ( + + ); +}; diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx index 8d83e8353f..34e403f9cc 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx @@ -5,14 +5,18 @@ import { ImageInputFieldTemplate, ImageInputFieldValue, } from 'features/nodes/types/types'; -import { memo, useCallback } from 'react'; +import { memo, useCallback, useMemo } from 'react'; -import { FieldComponentProps } from './types'; -import IAIDndImage from 'common/components/IAIDndImage'; -import { ImageDTO } from 'services/api/types'; import { Flex } from '@chakra-ui/react'; -import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { skipToken } from '@reduxjs/toolkit/dist/query'; +import { + TypesafeDraggableData, + TypesafeDroppableData, +} from 'app/components/ImageDnd/typesafeDnd'; +import IAIDndImage from 'common/components/IAIDndImage'; +import { useGetImageDTOQuery } from 'services/api/endpoints/images'; +import { PostUploadAction } from 'services/api/thunks/image'; +import { FieldComponentProps } from './types'; const ImageInputFieldComponent = ( props: FieldComponentProps @@ -22,29 +26,12 @@ const ImageInputFieldComponent = ( const dispatch = useAppDispatch(); const { - currentData: image, + currentData: imageDTO, isLoading, isError, isSuccess, } = useGetImageDTOQuery(field.value?.image_name ?? skipToken); - const handleDrop = useCallback( - ({ image_name }: ImageDTO) => { - if (field.value?.image_name === image_name) { - return; - } - - dispatch( - fieldValueChanged({ - nodeId, - fieldName: field.name, - value: { image_name }, - }) - ); - }, - [dispatch, field.name, field.value, nodeId] - ); - const handleReset = useCallback(() => { dispatch( fieldValueChanged({ @@ -55,6 +42,34 @@ const ImageInputFieldComponent = ( ); }, [dispatch, field.name, nodeId]); + const draggableData = useMemo(() => { + if (imageDTO) { + return { + id: `node-${nodeId}-${field.name}`, + payloadType: 'IMAGE_DTO', + payload: { imageDTO }, + }; + } + }, [field.name, imageDTO, nodeId]); + + const droppableData = useMemo( + () => ({ + id: `node-${nodeId}-${field.name}`, + actionType: 'SET_NODES_IMAGE', + context: { nodeId, fieldName: field.name }, + }), + [field.name, nodeId] + ); + + const postUploadAction = useMemo( + () => ({ + type: 'SET_NODES_IMAGE', + nodeId, + fieldName: field.name, + }), + [nodeId, field.name] + ); + return ( ); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx new file mode 100644 index 0000000000..02cdfd454d --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx @@ -0,0 +1,102 @@ +import { SelectItem } from '@mantine/core'; +import { useAppDispatch } from 'app/store/storeHooks'; +import IAIMantineSelect from 'common/components/IAIMantineSelect'; +import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; +import { + VaeModelInputFieldTemplate, + VaeModelInputFieldValue, +} from 'features/nodes/types/types'; +import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect'; +import { forEach, isString } from 'lodash-es'; +import { memo, useCallback, useEffect, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; +import { FieldComponentProps } from './types'; + +const LoRAModelInputFieldComponent = ( + props: FieldComponentProps< + VaeModelInputFieldValue, + VaeModelInputFieldTemplate + > +) => { + const { nodeId, field } = props; + + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + + const { data: loraModels } = useGetLoRAModelsQuery(); + + const selectedModel = useMemo( + () => loraModels?.entities[field.value ?? loraModels.ids[0]], + [loraModels?.entities, loraModels?.ids, field.value] + ); + + const data = useMemo(() => { + if (!loraModels) { + return []; + } + + const data: SelectItem[] = []; + + forEach(loraModels.entities, (model, id) => { + if (!model) { + return; + } + + data.push({ + value: id, + label: model.name, + group: BASE_MODEL_NAME_MAP[model.base_model], + }); + }); + + return data; + }, [loraModels]); + + const handleValueChanged = useCallback( + (v: string | null) => { + if (!v) { + return; + } + + dispatch( + fieldValueChanged({ + nodeId, + fieldName: field.name, + value: v, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + + useEffect(() => { + if (field.value && loraModels?.ids.includes(field.value)) { + return; + } + + const firstLora = loraModels?.ids[0]; + + if (!isString(firstLora)) { + return; + } + + handleValueChanged(firstLora); + }, [field.value, handleValueChanged, loraModels?.ids]); + + return ( + + ); +}; + +export default memo(LoRAModelInputFieldComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx index 741662655f..ee739e1002 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx @@ -6,13 +6,13 @@ import { ModelInputFieldValue, } from 'features/nodes/types/types'; -import { memo, useCallback, useEffect, useMemo } from 'react'; -import { FieldComponentProps } from './types'; -import { forEach, isString } from 'lodash-es'; -import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; +import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect'; +import { forEach, isString } from 'lodash-es'; +import { memo, useCallback, useEffect, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { useListModelsQuery } from 'services/api/endpoints/models'; +import { useGetMainModelsQuery } from 'services/api/endpoints/models'; +import { FieldComponentProps } from './types'; const ModelInputFieldComponent = ( props: FieldComponentProps @@ -22,18 +22,16 @@ const ModelInputFieldComponent = ( const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { data: pipelineModels } = useListModelsQuery({ - model_type: 'main', - }); + const { data: mainModels } = useGetMainModelsQuery(); const data = useMemo(() => { - if (!pipelineModels) { + if (!mainModels) { return []; } const data: SelectItem[] = []; - forEach(pipelineModels.entities, (model, id) => { + forEach(mainModels.entities, (model, id) => { if (!model) { return; } @@ -46,11 +44,11 @@ const ModelInputFieldComponent = ( }); return data; - }, [pipelineModels]); + }, [mainModels]); const selectedModel = useMemo( - () => pipelineModels?.entities[field.value ?? pipelineModels.ids[0]], - [pipelineModels?.entities, pipelineModels?.ids, field.value] + () => mainModels?.entities[field.value ?? mainModels.ids[0]], + [mainModels?.entities, mainModels?.ids, field.value] ); const handleValueChanged = useCallback( @@ -71,18 +69,18 @@ const ModelInputFieldComponent = ( ); useEffect(() => { - if (field.value && pipelineModels?.ids.includes(field.value)) { + if (field.value && mainModels?.ids.includes(field.value)) { return; } - const firstModel = pipelineModels?.ids[0]; + const firstModel = mainModels?.ids[0]; if (!isString(firstModel)) { return; } handleValueChanged(firstModel); - }, [field.value, handleValueChanged, pipelineModels?.ids]); + }, [field.value, handleValueChanged, mainModels?.ids]); return ( +) => { + const { nodeId, field } = props; + + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + + const { data: vaeModels } = useGetVaeModelsQuery(); + + const selectedModel = useMemo( + () => vaeModels?.entities[field.value ?? vaeModels.ids[0]], + [vaeModels?.entities, vaeModels?.ids, field.value] + ); + + const data = useMemo(() => { + if (!vaeModels) { + return []; + } + + const data: SelectItem[] = []; + + forEach(vaeModels.entities, (model, id) => { + if (!model) { + return; + } + + data.push({ + value: id, + label: model.name, + group: BASE_MODEL_NAME_MAP[model.base_model], + }); + }); + + return data; + }, [vaeModels]); + + const handleValueChanged = useCallback( + (v: string | null) => { + if (!v) { + return; + } + + dispatch( + fieldValueChanged({ + nodeId, + fieldName: field.name, + value: v, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + + useEffect(() => { + if (field.value && vaeModels?.ids.includes(field.value)) { + return; + } + handleValueChanged('auto'); + }, [field.value, handleValueChanged, vaeModels?.ids]); + + return ( + + ); +}; + +export default memo(VaeModelInputFieldComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/ui/NodeInvokeButton.tsx b/invokeai/frontend/web/src/features/nodes/components/ui/NodeInvokeButton.tsx index be5e5a943e..740fecc2a4 100644 --- a/invokeai/frontend/web/src/features/nodes/components/ui/NodeInvokeButton.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/ui/NodeInvokeButton.tsx @@ -45,6 +45,7 @@ export default function NodeInvokeButton(props: InvokeButton) { {!isReady && ( @@ -71,6 +71,12 @@ export default function NodeInvokeButton(props: InvokeButton) { tooltipProps={{ placement: 'bottom' }} colorScheme="accent" id="invoke-button" + _disabled={{ + background: 'none', + _hover: { + background: 'none', + }, + }} {...rest} /> ) : ( @@ -84,6 +90,12 @@ export default function NodeInvokeButton(props: InvokeButton) { colorScheme="accent" id="invoke-button" fontWeight={700} + _disabled={{ + background: 'none', + _hover: { + background: 'none', + }, + }} {...rest} > Invoke diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index ba217fff5f..4fa69c626b 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -1,5 +1,8 @@ import { createSlice, PayloadAction } from '@reduxjs/toolkit'; +import { RootState } from 'app/store/store'; +import { cloneDeep, uniqBy } from 'lodash-es'; import { OpenAPIV3 } from 'openapi-types'; +import { RgbaColor } from 'react-colorful'; import { addEdge, applyEdgeChanges, @@ -11,11 +14,9 @@ import { NodeChange, OnConnectStartParams, } from 'reactflow'; -import { ImageField } from 'services/api/types'; import { receivedOpenAPISchema } from 'services/api/thunks/schema'; +import { ImageField } from 'services/api/types'; import { InvocationTemplate, InvocationValue } from '../types/types'; -import { RgbaColor } from 'react-colorful'; -import { RootState } from 'app/store/store'; export type NodesState = { nodes: Node[]; @@ -62,7 +63,14 @@ const nodesSlice = createSlice({ action: PayloadAction<{ nodeId: string; fieldName: string; - value: string | number | boolean | ImageField | RgbaColor | undefined; + value: + | string + | number + | boolean + | ImageField + | RgbaColor + | undefined + | ImageField[]; }> ) => { const { nodeId, fieldName, value } = action.payload; @@ -72,6 +80,35 @@ const nodesSlice = createSlice({ state.nodes[nodeIndex].data.inputs[fieldName].value = value; } }, + imageCollectionFieldValueChanged: ( + state, + action: PayloadAction<{ + nodeId: string; + fieldName: string; + value: ImageField[]; + }> + ) => { + const { nodeId, fieldName, value } = action.payload; + const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId); + + if (nodeIndex === -1) { + return; + } + + const currentValue = cloneDeep( + state.nodes[nodeIndex].data.inputs[fieldName].value + ); + + if (!currentValue) { + state.nodes[nodeIndex].data.inputs[fieldName].value = value; + return; + } + + state.nodes[nodeIndex].data.inputs[fieldName].value = uniqBy( + (currentValue as ImageField[]).concat(value), + 'image_name' + ); + }, shouldShowGraphOverlayChanged: (state, action: PayloadAction) => { state.shouldShowGraphOverlay = action.payload; }, @@ -103,6 +140,7 @@ export const { shouldShowGraphOverlayChanged, nodeTemplatesBuilt, nodeEditorReset, + imageCollectionFieldValueChanged, } = nodesSlice.actions; export default nodesSlice.reducer; diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index 83fadb6bcb..5fe780a286 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -10,12 +10,15 @@ export const FIELD_TYPE_MAP: Record = { boolean: 'boolean', enum: 'enum', ImageField: 'image', + image_collection: 'image_collection', LatentsField: 'latents', ConditioningField: 'conditioning', UNetField: 'unet', ClipField: 'clip', VaeField: 'vae', model: 'model', + vae_model: 'vae_model', + lora_model: 'lora_model', array: 'array', item: 'item', ColorField: 'color', @@ -30,9 +33,6 @@ const COLOR_TOKEN_VALUE = 500; const getColorTokenCssVariable = (color: string) => `var(--invokeai-colors-${color}-${COLOR_TOKEN_VALUE})`; -// @ts-ignore -// @ts-ignore -// @ts-ignore export const FIELDS: Record = { integer: { color: 'red', @@ -70,6 +70,12 @@ export const FIELDS: Record = { title: 'Image', description: 'Images may be passed between nodes.', }, + image_collection: { + color: 'purple', + colorCssVar: getColorTokenCssVariable('purple'), + title: 'Image Collection', + description: 'A collection of images.', + }, latents: { color: 'pink', colorCssVar: getColorTokenCssVariable('pink'), @@ -112,6 +118,18 @@ export const FIELDS: Record = { title: 'Model', description: 'Models are models.', }, + vae_model: { + color: 'teal', + colorCssVar: getColorTokenCssVariable('teal'), + title: 'VAE', + description: 'Models are models.', + }, + lora_model: { + color: 'teal', + colorCssVar: getColorTokenCssVariable('teal'), + title: 'LoRA', + description: 'Models are models.', + }, array: { color: 'gray', colorCssVar: getColorTokenCssVariable('gray'), diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index 3faf2f9653..3de8cae9ff 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -64,9 +64,12 @@ export type FieldType = | 'vae' | 'control' | 'model' + | 'vae_model' + | 'lora_model' | 'array' | 'item' - | 'color'; + | 'color' + | 'image_collection'; /** * An input field is persisted across reloads as part of the user's local state. @@ -90,9 +93,12 @@ export type InputFieldValue = | ControlInputFieldValue | EnumInputFieldValue | ModelInputFieldValue + | VaeModelInputFieldValue + | LoRAModelInputFieldValue | ArrayInputFieldValue | ItemInputFieldValue - | ColorInputFieldValue; + | ColorInputFieldValue + | ImageCollectionInputFieldValue; /** * An input field template is generated on each page load from the OpenAPI schema. @@ -114,9 +120,12 @@ export type InputFieldTemplate = | ControlInputFieldTemplate | EnumInputFieldTemplate | ModelInputFieldTemplate + | VaeModelInputFieldTemplate + | LoRAModelInputFieldTemplate | ArrayInputFieldTemplate | ItemInputFieldTemplate - | ColorInputFieldTemplate; + | ColorInputFieldTemplate + | ImageCollectionInputFieldTemplate; /** * An output field is persisted across as part of the user's local state. @@ -215,11 +224,26 @@ export type ImageInputFieldValue = FieldValueBase & { value?: ImageField; }; +export type ImageCollectionInputFieldValue = FieldValueBase & { + type: 'image_collection'; + value?: ImageField[]; +}; + export type ModelInputFieldValue = FieldValueBase & { type: 'model'; value?: string; }; +export type VaeModelInputFieldValue = FieldValueBase & { + type: 'vae_model'; + value?: string; +}; + +export type LoRAModelInputFieldValue = FieldValueBase & { + type: 'lora_model'; + value?: string; +}; + export type ArrayInputFieldValue = FieldValueBase & { type: 'array'; value?: (string | number)[]; @@ -282,6 +306,11 @@ export type ImageInputFieldTemplate = InputFieldTemplateBase & { type: 'image'; }; +export type ImageCollectionInputFieldTemplate = InputFieldTemplateBase & { + default: ImageField[]; + type: 'image_collection'; +}; + export type LatentsInputFieldTemplate = InputFieldTemplateBase & { default: string; type: 'latents'; @@ -292,6 +321,21 @@ export type ConditioningInputFieldTemplate = InputFieldTemplateBase & { type: 'conditioning'; }; +export type UNetInputFieldTemplate = InputFieldTemplateBase & { + default: undefined; + type: 'unet'; +}; + +export type ClipInputFieldTemplate = InputFieldTemplateBase & { + default: undefined; + type: 'clip'; +}; + +export type VaeInputFieldTemplate = InputFieldTemplateBase & { + default: undefined; + type: 'vae'; +}; + export type ControlInputFieldTemplate = InputFieldTemplateBase & { default: undefined; type: 'control'; @@ -309,6 +353,16 @@ export type ModelInputFieldTemplate = InputFieldTemplateBase & { type: 'model'; }; +export type VaeModelInputFieldTemplate = InputFieldTemplateBase & { + default: string; + type: 'vae_model'; +}; + +export type LoRAModelInputFieldTemplate = InputFieldTemplateBase & { + default: string; + type: 'lora_model'; +}; + export type ArrayInputFieldTemplate = InputFieldTemplateBase & { default: []; type: 'array'; diff --git a/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts index 11ceb23763..5c4d67ebd3 100644 --- a/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts @@ -1,5 +1,5 @@ import { RootState } from 'app/store/store'; -import { filter } from 'lodash-es'; +import { getValidControlNets } from 'features/controlNet/util/getValidControlNets'; import { CollectInvocation, ControlNetInvocation } from 'services/api/types'; import { NonNullableGraph } from '../types/types'; import { CONTROL_NET_COLLECT } from './graphBuilders/constants'; @@ -11,13 +11,7 @@ export const addControlNetToLinearGraph = ( ): void => { const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet; - const validControlNets = filter( - controlNets, - (c) => - c.isEnabled && - (Boolean(c.processedControlImage) || - (c.processorType === 'none' && Boolean(c.controlImage))) - ); + const validControlNets = getValidControlNets(controlNets); if (isControlNetEnabled && Boolean(validControlNets.length)) { if (validControlNets.length > 1) { diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts index f1ad731d32..1c2dbc0c3e 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts @@ -3,26 +3,29 @@ import { OpenAPIV3 } from 'openapi-types'; import { FIELD_TYPE_MAP } from '../types/constants'; import { isSchemaObject } from '../types/typeGuards'; import { - BooleanInputFieldTemplate, - EnumInputFieldTemplate, - FloatInputFieldTemplate, - ImageInputFieldTemplate, - IntegerInputFieldTemplate, - LatentsInputFieldTemplate, - ConditioningInputFieldTemplate, - UNetInputFieldTemplate, - ClipInputFieldTemplate, - VaeInputFieldTemplate, - ControlInputFieldTemplate, - StringInputFieldTemplate, - ModelInputFieldTemplate, ArrayInputFieldTemplate, - ItemInputFieldTemplate, + BooleanInputFieldTemplate, + ClipInputFieldTemplate, ColorInputFieldTemplate, - InputFieldTemplateBase, - OutputFieldTemplate, - TypeHints, + ConditioningInputFieldTemplate, + ControlInputFieldTemplate, + EnumInputFieldTemplate, FieldType, + FloatInputFieldTemplate, + ImageCollectionInputFieldTemplate, + ImageInputFieldTemplate, + InputFieldTemplateBase, + IntegerInputFieldTemplate, + ItemInputFieldTemplate, + LatentsInputFieldTemplate, + LoRAModelInputFieldTemplate, + ModelInputFieldTemplate, + OutputFieldTemplate, + StringInputFieldTemplate, + TypeHints, + UNetInputFieldTemplate, + VaeInputFieldTemplate, + VaeModelInputFieldTemplate, } from '../types/types'; export type BaseFieldProperties = 'name' | 'title' | 'description'; @@ -174,6 +177,36 @@ const buildModelInputFieldTemplate = ({ return template; }; +const buildVaeModelInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): VaeModelInputFieldTemplate => { + const template: VaeModelInputFieldTemplate = { + ...baseField, + type: 'vae_model', + inputRequirement: 'always', + inputKind: 'direct', + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildLoRAModelInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): LoRAModelInputFieldTemplate => { + const template: LoRAModelInputFieldTemplate = { + ...baseField, + type: 'lora_model', + inputRequirement: 'always', + inputKind: 'direct', + default: schemaObject.default ?? undefined, + }; + + return template; +}; + const buildImageInputFieldTemplate = ({ schemaObject, baseField, @@ -189,6 +222,21 @@ const buildImageInputFieldTemplate = ({ return template; }; +const buildImageCollectionInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): ImageCollectionInputFieldTemplate => { + const template: ImageCollectionInputFieldTemplate = { + ...baseField, + type: 'image_collection', + inputRequirement: 'always', + inputKind: 'any', + default: schemaObject.default ?? undefined, + }; + + return template; +}; + const buildLatentsInputFieldTemplate = ({ schemaObject, baseField, @@ -400,6 +448,10 @@ export const buildInputFieldTemplate = ( if (['image'].includes(fieldType)) { return buildImageInputFieldTemplate({ schemaObject, baseField }); } + + if (['image_collection'].includes(fieldType)) { + return buildImageCollectionInputFieldTemplate({ schemaObject, baseField }); + } if (['latents'].includes(fieldType)) { return buildLatentsInputFieldTemplate({ schemaObject, baseField }); } @@ -421,6 +473,12 @@ export const buildInputFieldTemplate = ( if (['model'].includes(fieldType)) { return buildModelInputFieldTemplate({ schemaObject, baseField }); } + if (['vae_model'].includes(fieldType)) { + return buildVaeModelInputFieldTemplate({ schemaObject, baseField }); + } + if (['lora_model'].includes(fieldType)) { + return buildLoRAModelInputFieldTemplate({ schemaObject, baseField }); + } if (['enum'].includes(fieldType)) { return buildEnumInputFieldTemplate({ schemaObject, baseField }); } diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts index 1703c45331..950038b691 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts @@ -44,6 +44,10 @@ export const buildInputFieldValue = ( fieldValue.value = undefined; } + if (template.type === 'image_collection') { + fieldValue.value = []; + } + if (template.type === 'latents') { fieldValue.value = undefined; } @@ -71,6 +75,14 @@ export const buildInputFieldValue = ( if (template.type === 'model') { fieldValue.value = undefined; } + + if (template.type === 'vae_model') { + fieldValue.value = undefined; + } + + if (template.type === 'lora_model') { + fieldValue.value = undefined; + } } return fieldValue; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts new file mode 100644 index 0000000000..9712ef4d5f --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts @@ -0,0 +1,148 @@ +import { RootState } from 'app/store/store'; +import { NonNullableGraph } from 'features/nodes/types/types'; +import { forEach, size } from 'lodash-es'; +import { LoraLoaderInvocation } from 'services/api/types'; +import { modelIdToLoRAModelField } from '../modelIdToLoRAName'; +import { + LORA_LOADER, + MAIN_MODEL_LOADER, + NEGATIVE_CONDITIONING, + POSITIVE_CONDITIONING, +} from './constants'; + +export const addLoRAsToGraph = ( + graph: NonNullableGraph, + state: RootState, + baseNodeId: string +): void => { + /** + * LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them. + * They then output the UNet and CLIP models references on to either the next LoRA in the chain, + * or to the inference/conditioning nodes. + * + * So we need to inject a LoRA chain into the graph. + */ + + const { loras } = state.lora; + const loraCount = size(loras); + + if (loraCount > 0) { + // remove any existing connections from main model loader, we need to insert the lora nodes + graph.edges = graph.edges.filter( + (e) => + !( + e.source.node_id === MAIN_MODEL_LOADER && + ['unet', 'clip'].includes(e.source.field) + ) + ); + } + + // we need to remember the last lora so we can chain from it + let lastLoraNodeId = ''; + let currentLoraIndex = 0; + + forEach(loras, (lora) => { + const { id, name, weight } = lora; + const loraField = modelIdToLoRAModelField(id); + const currentLoraNodeId = `${LORA_LOADER}_${loraField.model_name.replace( + '.', + '_' + )}`; + + const loraLoaderNode: LoraLoaderInvocation = { + type: 'lora_loader', + id: currentLoraNodeId, + lora: loraField, + weight, + }; + + graph.nodes[currentLoraNodeId] = loraLoaderNode; + + if (currentLoraIndex === 0) { + // first lora = start the lora chain, attach directly to model loader + graph.edges.push({ + source: { + node_id: MAIN_MODEL_LOADER, + field: 'unet', + }, + destination: { + node_id: currentLoraNodeId, + field: 'unet', + }, + }); + + graph.edges.push({ + source: { + node_id: MAIN_MODEL_LOADER, + field: 'clip', + }, + destination: { + node_id: currentLoraNodeId, + field: 'clip', + }, + }); + } else { + // we are in the middle of the lora chain, instead connect to the previous lora + graph.edges.push({ + source: { + node_id: lastLoraNodeId, + field: 'unet', + }, + destination: { + node_id: currentLoraNodeId, + field: 'unet', + }, + }); + graph.edges.push({ + source: { + node_id: lastLoraNodeId, + field: 'clip', + }, + destination: { + node_id: currentLoraNodeId, + field: 'clip', + }, + }); + } + + if (currentLoraIndex === loraCount - 1) { + // final lora, end the lora chain - we need to connect up to inference and conditioning nodes + graph.edges.push({ + source: { + node_id: currentLoraNodeId, + field: 'unet', + }, + destination: { + node_id: baseNodeId, + field: 'unet', + }, + }); + + graph.edges.push({ + source: { + node_id: currentLoraNodeId, + field: 'clip', + }, + destination: { + node_id: POSITIVE_CONDITIONING, + field: 'clip', + }, + }); + + graph.edges.push({ + source: { + node_id: currentLoraNodeId, + field: 'clip', + }, + destination: { + node_id: NEGATIVE_CONDITIONING, + field: 'clip', + }, + }); + } + + // increment the lora for the next one in the chain + lastLoraNodeId = currentLoraNodeId; + currentLoraIndex += 1; + }); +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts new file mode 100644 index 0000000000..4dd3d644ee --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts @@ -0,0 +1,68 @@ +import { RootState } from 'app/store/store'; +import { NonNullableGraph } from 'features/nodes/types/types'; +import { modelIdToVAEModelField } from '../modelIdToVAEModelField'; +import { + IMAGE_TO_IMAGE_GRAPH, + IMAGE_TO_LATENTS, + INPAINT, + INPAINT_GRAPH, + LATENTS_TO_IMAGE, + MAIN_MODEL_LOADER, + TEXT_TO_IMAGE_GRAPH, + VAE_LOADER, +} from './constants'; + +export const addVAEToGraph = ( + graph: NonNullableGraph, + state: RootState +): void => { + const { vae: vaeId } = state.generation; + const vae_model = modelIdToVAEModelField(vaeId); + + if (vaeId !== 'auto') { + graph.nodes[VAE_LOADER] = { + type: 'vae_loader', + id: VAE_LOADER, + vae_model, + }; + } + + if (graph.id === TEXT_TO_IMAGE_GRAPH || graph.id === IMAGE_TO_IMAGE_GRAPH) { + graph.edges.push({ + source: { + node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER, + field: 'vae', + }, + destination: { + node_id: LATENTS_TO_IMAGE, + field: 'vae', + }, + }); + } + + if (graph.id === IMAGE_TO_IMAGE_GRAPH) { + graph.edges.push({ + source: { + node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER, + field: 'vae', + }, + destination: { + node_id: IMAGE_TO_LATENTS, + field: 'vae', + }, + }); + } + + if (graph.id === INPAINT_GRAPH) { + graph.edges.push({ + source: { + node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER, + field: 'vae', + }, + destination: { + node_id: INPAINT, + field: 'vae', + }, + }); + } +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts index 49bab291f7..1843efef84 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts @@ -1,31 +1,27 @@ +import { log } from 'app/logging/useLogger'; import { RootState } from 'app/store/store'; +import { NonNullableGraph } from 'features/nodes/types/types'; import { ImageDTO, ImageResizeInvocation, ImageToLatentsInvocation, - RandomIntInvocation, - RangeOfSizeInvocation, } from 'services/api/types'; -import { NonNullableGraph } from 'features/nodes/types/types'; -import { log } from 'app/logging/useLogger'; +import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; +import { modelIdToMainModelField } from '../modelIdToMainModelField'; +import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addLoRAsToGraph } from './addLoRAsToGraph'; +import { addVAEToGraph } from './addVAEToGraph'; import { - ITERATE, + IMAGE_TO_IMAGE_GRAPH, + IMAGE_TO_LATENTS, LATENTS_TO_IMAGE, - PIPELINE_MODEL_LOADER, + LATENTS_TO_LATENTS, + MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, NOISE, POSITIVE_CONDITIONING, - RANDOM_INT, - RANGE_OF_SIZE, - IMAGE_TO_IMAGE_GRAPH, - IMAGE_TO_LATENTS, - LATENTS_TO_LATENTS, RESIZE, } from './constants'; -import { set } from 'lodash-es'; -import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; -import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; -import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; const moduleLog = log.child({ namespace: 'nodes' }); @@ -52,7 +48,7 @@ export const buildCanvasImageToImageGraph = ( // The bounding box determines width and height, not the width and height params const { width, height } = state.canvas.boundingBoxDimensions; - const model = modelIdToPipelineModelField(modelId); + const model = modelIdToMainModelField(modelId); /** * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the @@ -81,9 +77,9 @@ export const buildCanvasImageToImageGraph = ( type: 'noise', id: NOISE, }, - [PIPELINE_MODEL_LOADER]: { - type: 'pipeline_model_loader', - id: PIPELINE_MODEL_LOADER, + [MAIN_MODEL_LOADER]: { + type: 'main_model_loader', + id: MAIN_MODEL_LOADER, model, }, [LATENTS_TO_IMAGE]: { @@ -110,7 +106,7 @@ export const buildCanvasImageToImageGraph = ( edges: [ { source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'clip', }, destination: { @@ -120,7 +116,7 @@ export const buildCanvasImageToImageGraph = ( }, { source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'clip', }, destination: { @@ -128,16 +124,6 @@ export const buildCanvasImageToImageGraph = ( field: 'clip', }, }, - { - source: { - node_id: PIPELINE_MODEL_LOADER, - field: 'vae', - }, - destination: { - node_id: LATENTS_TO_IMAGE, - field: 'vae', - }, - }, { source: { node_id: LATENTS_TO_LATENTS, @@ -170,17 +156,7 @@ export const buildCanvasImageToImageGraph = ( }, { source: { - node_id: PIPELINE_MODEL_LOADER, - field: 'vae', - }, - destination: { - node_id: IMAGE_TO_LATENTS, - field: 'vae', - }, - }, - { - source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'unet', }, destination: { @@ -277,6 +253,11 @@ export const buildCanvasImageToImageGraph = ( }); } + addLoRAsToGraph(graph, state, LATENTS_TO_LATENTS); + + // Add VAE + addVAEToGraph(graph, state); + // add dynamic prompts, mutating `graph` addDynamicPromptsToGraph(graph, state); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts index 74bd12a742..c4f9415067 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts @@ -1,23 +1,25 @@ +import { log } from 'app/logging/useLogger'; import { RootState } from 'app/store/store'; +import { NonNullableGraph } from 'features/nodes/types/types'; import { ImageDTO, InpaintInvocation, RandomIntInvocation, RangeOfSizeInvocation, } from 'services/api/types'; -import { NonNullableGraph } from 'features/nodes/types/types'; -import { log } from 'app/logging/useLogger'; +import { modelIdToMainModelField } from '../modelIdToMainModelField'; +import { addLoRAsToGraph } from './addLoRAsToGraph'; +import { addVAEToGraph } from './addVAEToGraph'; import { + INPAINT, + INPAINT_GRAPH, ITERATE, - PIPELINE_MODEL_LOADER, + MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, POSITIVE_CONDITIONING, RANDOM_INT, RANGE_OF_SIZE, - INPAINT_GRAPH, - INPAINT, } from './constants'; -import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; const moduleLog = log.child({ namespace: 'nodes' }); @@ -55,7 +57,7 @@ export const buildCanvasInpaintGraph = ( // We may need to set the inpaint width and height to scale the image const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas; - const model = modelIdToPipelineModelField(modelId); + const model = modelIdToMainModelField(modelId); const graph: NonNullableGraph = { id: INPAINT_GRAPH, @@ -101,9 +103,9 @@ export const buildCanvasInpaintGraph = ( id: NEGATIVE_CONDITIONING, prompt: negativePrompt, }, - [PIPELINE_MODEL_LOADER]: { - type: 'pipeline_model_loader', - id: PIPELINE_MODEL_LOADER, + [MAIN_MODEL_LOADER]: { + type: 'main_model_loader', + id: MAIN_MODEL_LOADER, model, }, [RANGE_OF_SIZE]: { @@ -142,7 +144,7 @@ export const buildCanvasInpaintGraph = ( }, { source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'clip', }, destination: { @@ -152,7 +154,7 @@ export const buildCanvasInpaintGraph = ( }, { source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'clip', }, destination: { @@ -162,7 +164,7 @@ export const buildCanvasInpaintGraph = ( }, { source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'unet', }, destination: { @@ -170,16 +172,6 @@ export const buildCanvasInpaintGraph = ( field: 'unet', }, }, - { - source: { - node_id: PIPELINE_MODEL_LOADER, - field: 'vae', - }, - destination: { - node_id: INPAINT, - field: 'vae', - }, - }, { source: { node_id: RANGE_OF_SIZE, @@ -203,6 +195,11 @@ export const buildCanvasInpaintGraph = ( ], }; + addLoRAsToGraph(graph, state, INPAINT); + + // Add VAE + addVAEToGraph(graph, state); + // handle seed if (shouldRandomizeSeed) { // Random int node to generate the starting seed diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts index b15b2cd192..976ea4fd01 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts @@ -1,21 +1,19 @@ import { RootState } from 'app/store/store'; import { NonNullableGraph } from 'features/nodes/types/types'; -import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api/types'; +import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; +import { modelIdToMainModelField } from '../modelIdToMainModelField'; +import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addLoRAsToGraph } from './addLoRAsToGraph'; +import { addVAEToGraph } from './addVAEToGraph'; import { - ITERATE, LATENTS_TO_IMAGE, - PIPELINE_MODEL_LOADER, + MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, NOISE, POSITIVE_CONDITIONING, - RANDOM_INT, - RANGE_OF_SIZE, TEXT_TO_IMAGE_GRAPH, TEXT_TO_LATENTS, } from './constants'; -import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; -import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; -import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; /** * Builds the Canvas tab's Text to Image graph. @@ -38,7 +36,7 @@ export const buildCanvasTextToImageGraph = ( // The bounding box determines width and height, not the width and height params const { width, height } = state.canvas.boundingBoxDimensions; - const model = modelIdToPipelineModelField(modelId); + const model = modelIdToMainModelField(modelId); /** * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the @@ -76,9 +74,9 @@ export const buildCanvasTextToImageGraph = ( scheduler, steps, }, - [PIPELINE_MODEL_LOADER]: { - type: 'pipeline_model_loader', - id: PIPELINE_MODEL_LOADER, + [MAIN_MODEL_LOADER]: { + type: 'main_model_loader', + id: MAIN_MODEL_LOADER, model, }, [LATENTS_TO_IMAGE]: { @@ -109,7 +107,7 @@ export const buildCanvasTextToImageGraph = ( }, { source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'clip', }, destination: { @@ -119,7 +117,7 @@ export const buildCanvasTextToImageGraph = ( }, { source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'clip', }, destination: { @@ -129,7 +127,7 @@ export const buildCanvasTextToImageGraph = ( }, { source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'unet', }, destination: { @@ -147,16 +145,6 @@ export const buildCanvasTextToImageGraph = ( field: 'latents', }, }, - { - source: { - node_id: PIPELINE_MODEL_LOADER, - field: 'vae', - }, - destination: { - node_id: LATENTS_TO_IMAGE, - field: 'vae', - }, - }, { source: { node_id: NOISE, @@ -170,6 +158,11 @@ export const buildCanvasTextToImageGraph = ( ], }; + addLoRAsToGraph(graph, state, TEXT_TO_LATENTS); + + // Add VAE + addVAEToGraph(graph, state); + // add dynamic prompts, mutating `graph` addDynamicPromptsToGraph(graph, state); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts index 15d5a431a2..fe6d1292e4 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts @@ -1,24 +1,30 @@ +import { log } from 'app/logging/useLogger'; import { RootState } from 'app/store/store'; +import { NonNullableGraph } from 'features/nodes/types/types'; import { + ImageCollectionInvocation, ImageResizeInvocation, ImageToLatentsInvocation, + IterateInvocation, } from 'services/api/types'; -import { NonNullableGraph } from 'features/nodes/types/types'; -import { log } from 'app/logging/useLogger'; +import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; +import { modelIdToMainModelField } from '../modelIdToMainModelField'; +import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addLoRAsToGraph } from './addLoRAsToGraph'; +import { addVAEToGraph } from './addVAEToGraph'; import { + IMAGE_COLLECTION, + IMAGE_COLLECTION_ITERATE, + IMAGE_TO_IMAGE_GRAPH, + IMAGE_TO_LATENTS, LATENTS_TO_IMAGE, - PIPELINE_MODEL_LOADER, + LATENTS_TO_LATENTS, + MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, NOISE, POSITIVE_CONDITIONING, - IMAGE_TO_IMAGE_GRAPH, - IMAGE_TO_LATENTS, - LATENTS_TO_LATENTS, RESIZE, } from './constants'; -import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; -import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; -import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; const moduleLog = log.child({ namespace: 'nodes' }); @@ -42,6 +48,15 @@ export const buildLinearImageToImageGraph = ( height, } = state.generation; + const { + isEnabled: isBatchEnabled, + imageNames: batchImageNames, + asInitialImage, + } = state.batch; + + const shouldBatch = + isBatchEnabled && batchImageNames.length > 0 && asInitialImage; + /** * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * full graph here as a template. Then use the parameters from app state and set friendlier node @@ -51,12 +66,12 @@ export const buildLinearImageToImageGraph = ( * the `fit` param. These are added to the graph at the end. */ - if (!initialImage) { + if (!initialImage && !shouldBatch) { moduleLog.error('No initial image found in state'); throw new Error('No initial image found in state'); } - const model = modelIdToPipelineModelField(modelId); + const model = modelIdToMainModelField(modelId); // copy-pasted graph from node editor, filled in with state values & friendly node ids const graph: NonNullableGraph = { @@ -76,9 +91,9 @@ export const buildLinearImageToImageGraph = ( type: 'noise', id: NOISE, }, - [PIPELINE_MODEL_LOADER]: { - type: 'pipeline_model_loader', - id: PIPELINE_MODEL_LOADER, + [MAIN_MODEL_LOADER]: { + type: 'main_model_loader', + id: MAIN_MODEL_LOADER, model, }, [LATENTS_TO_IMAGE]: { @@ -105,7 +120,7 @@ export const buildLinearImageToImageGraph = ( edges: [ { source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'clip', }, destination: { @@ -115,7 +130,7 @@ export const buildLinearImageToImageGraph = ( }, { source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'clip', }, destination: { @@ -123,16 +138,6 @@ export const buildLinearImageToImageGraph = ( field: 'clip', }, }, - { - source: { - node_id: PIPELINE_MODEL_LOADER, - field: 'vae', - }, - destination: { - node_id: LATENTS_TO_IMAGE, - field: 'vae', - }, - }, { source: { node_id: LATENTS_TO_LATENTS, @@ -163,19 +168,10 @@ export const buildLinearImageToImageGraph = ( field: 'noise', }, }, + { source: { - node_id: PIPELINE_MODEL_LOADER, - field: 'vae', - }, - destination: { - node_id: IMAGE_TO_LATENTS, - field: 'vae', - }, - }, - { - source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'unet', }, destination: { @@ -275,6 +271,46 @@ export const buildLinearImageToImageGraph = ( }); } + if (isBatchEnabled && asInitialImage && batchImageNames.length > 0) { + // we are going to connect an iterate up to the init image + delete (graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image; + + const imageCollection: ImageCollectionInvocation = { + id: IMAGE_COLLECTION, + type: 'image_collection', + images: batchImageNames.map((image_name) => ({ image_name })), + }; + + const imageCollectionIterate: IterateInvocation = { + id: IMAGE_COLLECTION_ITERATE, + type: 'iterate', + }; + + graph.nodes[IMAGE_COLLECTION] = imageCollection; + graph.nodes[IMAGE_COLLECTION_ITERATE] = imageCollectionIterate; + + graph.edges.push({ + source: { node_id: IMAGE_COLLECTION, field: 'collection' }, + destination: { + node_id: IMAGE_COLLECTION_ITERATE, + field: 'collection', + }, + }); + + graph.edges.push({ + source: { node_id: IMAGE_COLLECTION_ITERATE, field: 'item' }, + destination: { + node_id: IMAGE_TO_LATENTS, + field: 'image', + }, + }); + } + + addLoRAsToGraph(graph, state, LATENTS_TO_LATENTS); + + // Add VAE + addVAEToGraph(graph, state); + // add dynamic prompts, mutating `graph` addDynamicPromptsToGraph(graph, state); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts index 216c5c8c67..04dccf4983 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts @@ -1,17 +1,19 @@ import { RootState } from 'app/store/store'; import { NonNullableGraph } from 'features/nodes/types/types'; +import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; +import { modelIdToMainModelField } from '../modelIdToMainModelField'; +import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addLoRAsToGraph } from './addLoRAsToGraph'; +import { addVAEToGraph } from './addVAEToGraph'; import { LATENTS_TO_IMAGE, - PIPELINE_MODEL_LOADER, + MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, NOISE, POSITIVE_CONDITIONING, TEXT_TO_IMAGE_GRAPH, TEXT_TO_LATENTS, } from './constants'; -import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; -import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; -import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; export const buildLinearTextToImageGraph = ( state: RootState @@ -27,7 +29,7 @@ export const buildLinearTextToImageGraph = ( height, } = state.generation; - const model = modelIdToPipelineModelField(modelId); + const model = modelIdToMainModelField(modelId); /** * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the @@ -65,9 +67,9 @@ export const buildLinearTextToImageGraph = ( scheduler, steps, }, - [PIPELINE_MODEL_LOADER]: { - type: 'pipeline_model_loader', - id: PIPELINE_MODEL_LOADER, + [MAIN_MODEL_LOADER]: { + type: 'main_model_loader', + id: MAIN_MODEL_LOADER, model, }, [LATENTS_TO_IMAGE]: { @@ -98,7 +100,7 @@ export const buildLinearTextToImageGraph = ( }, { source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'clip', }, destination: { @@ -108,7 +110,7 @@ export const buildLinearTextToImageGraph = ( }, { source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'clip', }, destination: { @@ -118,7 +120,7 @@ export const buildLinearTextToImageGraph = ( }, { source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'unet', }, destination: { @@ -136,16 +138,6 @@ export const buildLinearTextToImageGraph = ( field: 'latents', }, }, - { - source: { - node_id: PIPELINE_MODEL_LOADER, - field: 'vae', - }, - destination: { - node_id: LATENTS_TO_IMAGE, - field: 'vae', - }, - }, { source: { node_id: NOISE, @@ -159,6 +151,11 @@ export const buildLinearTextToImageGraph = ( ], }; + addLoRAsToGraph(graph, state, TEXT_TO_LATENTS); + + // Add Custom VAE Support + addVAEToGraph(graph, state); + // add dynamic prompts, mutating `graph` addDynamicPromptsToGraph(graph, state); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts index 091899a21a..12a567b009 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts @@ -1,10 +1,12 @@ -import { Graph } from 'services/api/types'; -import { v4 as uuidv4 } from 'uuid'; -import { cloneDeep, omit, reduce } from 'lodash-es'; import { RootState } from 'app/store/store'; import { InputFieldValue } from 'features/nodes/types/types'; +import { cloneDeep, omit, reduce } from 'lodash-es'; +import { Graph } from 'services/api/types'; import { AnyInvocation } from 'services/events/types'; -import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; +import { v4 as uuidv4 } from 'uuid'; +import { modelIdToLoRAModelField } from '../modelIdToLoRAName'; +import { modelIdToMainModelField } from '../modelIdToMainModelField'; +import { modelIdToVAEModelField } from '../modelIdToVAEModelField'; /** * We need to do special handling for some fields @@ -27,7 +29,19 @@ export const parseFieldValue = (field: InputFieldValue) => { if (field.type === 'model') { if (field.value) { - return modelIdToPipelineModelField(field.value); + return modelIdToMainModelField(field.value); + } + } + + if (field.type === 'vae_model') { + if (field.value) { + return modelIdToVAEModelField(field.value); + } + } + + if (field.type === 'lora_model') { + if (field.value) { + return modelIdToLoRAModelField(field.value); } } diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts index d6ab33a6ea..7aace48def 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts @@ -7,13 +7,17 @@ export const NOISE = 'noise'; export const RANDOM_INT = 'rand_int'; export const RANGE_OF_SIZE = 'range_of_size'; export const ITERATE = 'iterate'; -export const PIPELINE_MODEL_LOADER = 'pipeline_model_loader'; +export const MAIN_MODEL_LOADER = 'main_model_loader'; +export const VAE_LOADER = 'vae_loader'; +export const LORA_LOADER = 'lora_loader'; export const IMAGE_TO_LATENTS = 'image_to_latents'; export const LATENTS_TO_LATENTS = 'latents_to_latents'; export const RESIZE = 'resize_image'; export const INPAINT = 'inpaint'; export const CONTROL_NET_COLLECT = 'control_net_collect'; export const DYNAMIC_PROMPT = 'dynamic_prompt'; +export const IMAGE_COLLECTION = 'image_collection'; +export const IMAGE_COLLECTION_ITERATE = 'image_collection_iterate'; // friendly graph ids export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph'; diff --git a/invokeai/frontend/web/src/features/nodes/util/modelIdToLoRAName.ts b/invokeai/frontend/web/src/features/nodes/util/modelIdToLoRAName.ts new file mode 100644 index 0000000000..052b58484b --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/modelIdToLoRAName.ts @@ -0,0 +1,12 @@ +import { BaseModelType, LoRAModelField } from 'services/api/types'; + +export const modelIdToLoRAModelField = (loraId: string): LoRAModelField => { + const [base_model, model_type, model_name] = loraId.split('/'); + + const field: LoRAModelField = { + base_model: base_model as BaseModelType, + model_name, + }; + + return field; +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/modelIdToMainModelField.ts b/invokeai/frontend/web/src/features/nodes/util/modelIdToMainModelField.ts new file mode 100644 index 0000000000..6bb0f776b2 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/modelIdToMainModelField.ts @@ -0,0 +1,16 @@ +import { BaseModelType, MainModelField } from 'services/api/types'; + +/** + * Crudely converts a model id to a main model field + * TODO: Make better + */ +export const modelIdToMainModelField = (modelId: string): MainModelField => { + const [base_model, model_type, model_name] = modelId.split('/'); + + const field: MainModelField = { + base_model: base_model as BaseModelType, + model_name, + }; + + return field; +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/modelIdToPipelineModelField.ts b/invokeai/frontend/web/src/features/nodes/util/modelIdToPipelineModelField.ts deleted file mode 100644 index 0941255181..0000000000 --- a/invokeai/frontend/web/src/features/nodes/util/modelIdToPipelineModelField.ts +++ /dev/null @@ -1,18 +0,0 @@ -import { BaseModelType, PipelineModelField } from 'services/api/types'; - -/** - * Crudely converts a model id to a pipeline model field - * TODO: Make better - */ -export const modelIdToPipelineModelField = ( - modelId: string -): PipelineModelField => { - const [base_model, model_type, model_name] = modelId.split('/'); - - const field: PipelineModelField = { - base_model: base_model as BaseModelType, - model_name, - }; - - return field; -}; diff --git a/invokeai/frontend/web/src/features/nodes/util/modelIdToVAEModelField.ts b/invokeai/frontend/web/src/features/nodes/util/modelIdToVAEModelField.ts new file mode 100644 index 0000000000..0cb608a936 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/modelIdToVAEModelField.ts @@ -0,0 +1,16 @@ +import { BaseModelType, VAEModelField } from 'services/api/types'; + +/** + * Crudely converts a model id to a main model field + * TODO: Make better + */ +export const modelIdToVAEModelField = (modelId: string): VAEModelField => { + const [base_model, model_type, model_name] = modelId.split('/'); + + const field: VAEModelField = { + base_model: base_model as BaseModelType, + model_name, + }; + + return field; +}; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxCollapse.tsx index fea0d8330a..b9cc8511aa 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxCollapse.tsx @@ -1,20 +1,15 @@ -import { Flex, useDisclosure } from '@chakra-ui/react'; -import { useTranslation } from 'react-i18next'; +import { Flex } from '@chakra-ui/react'; import IAICollapse from 'common/components/IAICollapse'; import { memo } from 'react'; -import ParamBoundingBoxWidth from './ParamBoundingBoxWidth'; +import { useTranslation } from 'react-i18next'; import ParamBoundingBoxHeight from './ParamBoundingBoxHeight'; +import ParamBoundingBoxWidth from './ParamBoundingBoxWidth'; const ParamBoundingBoxCollapse = () => { const { t } = useTranslation(); - const { isOpen, onToggle } = useDisclosure(); return ( - + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse.tsx index ed01da9876..a531eba57f 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse.tsx @@ -1,4 +1,4 @@ -import { Flex, useDisclosure } from '@chakra-ui/react'; +import { Flex } from '@chakra-ui/react'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -6,19 +6,14 @@ import IAICollapse from 'common/components/IAICollapse'; import ParamInfillMethod from './ParamInfillMethod'; import ParamInfillTilesize from './ParamInfillTilesize'; import ParamScaleBeforeProcessing from './ParamScaleBeforeProcessing'; -import ParamScaledWidth from './ParamScaledWidth'; import ParamScaledHeight from './ParamScaledHeight'; +import ParamScaledWidth from './ParamScaledWidth'; const ParamInfillCollapse = () => { const { t } = useTranslation(); - const { isOpen, onToggle } = useDisclosure(); return ( - + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse.tsx index 992e8b6d02..88d839fa15 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse.tsx @@ -1,22 +1,16 @@ +import IAICollapse from 'common/components/IAICollapse'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; import ParamSeamBlur from './ParamSeamBlur'; import ParamSeamSize from './ParamSeamSize'; import ParamSeamSteps from './ParamSeamSteps'; import ParamSeamStrength from './ParamSeamStrength'; -import { useDisclosure } from '@chakra-ui/react'; -import { useTranslation } from 'react-i18next'; -import IAICollapse from 'common/components/IAICollapse'; -import { memo } from 'react'; const ParamSeamCorrectionCollapse = () => { const { t } = useTranslation(); - const { isOpen, onToggle } = useDisclosure(); return ( - + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx index 06c6108dcb..59bf7542eb 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx @@ -1,41 +1,45 @@ import { Divider, Flex } from '@chakra-ui/react'; -import { useTranslation } from 'react-i18next'; -import IAICollapse from 'common/components/IAICollapse'; -import { Fragment, memo, useCallback } from 'react'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { createSelector } from '@reduxjs/toolkit'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIButton from 'common/components/IAIButton'; +import IAICollapse from 'common/components/IAICollapse'; +import ControlNet from 'features/controlNet/components/ControlNet'; +import ParamControlNetFeatureToggle from 'features/controlNet/components/parameters/ParamControlNetFeatureToggle'; import { controlNetAdded, controlNetSelector, - isControlNetEnabledToggled, } from 'features/controlNet/store/controlNetSlice'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { map } from 'lodash-es'; -import { v4 as uuidv4 } from 'uuid'; +import { getValidControlNets } from 'features/controlNet/util/getValidControlNets'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; -import IAIButton from 'common/components/IAIButton'; -import ControlNet from 'features/controlNet/components/ControlNet'; +import { map } from 'lodash-es'; +import { Fragment, memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { v4 as uuidv4 } from 'uuid'; const selector = createSelector( controlNetSelector, (controlNet) => { const { controlNets, isEnabled } = controlNet; - return { controlNetsArray: map(controlNets), isEnabled }; + const validControlNets = getValidControlNets(controlNets); + + const activeLabel = + isEnabled && validControlNets.length > 0 + ? `${validControlNets.length} Active` + : undefined; + + return { controlNetsArray: map(controlNets), activeLabel }; }, defaultSelectorOptions ); const ParamControlNetCollapse = () => { const { t } = useTranslation(); - const { controlNetsArray, isEnabled } = useAppSelector(selector); + const { controlNetsArray, activeLabel } = useAppSelector(selector); const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled; const dispatch = useAppDispatch(); - const handleClickControlNetToggle = useCallback(() => { - dispatch(isControlNetEnabledToggled()); - }, [dispatch]); - const handleClickedAddControlNet = useCallback(() => { dispatch(controlNetAdded({ controlNetId: uuidv4() })); }, [dispatch]); @@ -45,13 +49,9 @@ const ParamControlNetCollapse = () => { } return ( - + + {controlNetsArray.map((c, i) => ( {i > 0 && } diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamCFGScale.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamCFGScale.tsx index 111e3d3ae8..d32ff960d5 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamCFGScale.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamCFGScale.tsx @@ -1,5 +1,6 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAINumberInput from 'common/components/IAINumberInput'; import IAISlider from 'common/components/IAISlider'; import { generationSelector } from 'features/parameters/store/generationSelectors'; @@ -27,7 +28,8 @@ const selector = createSelector( shouldUseSliders, shift, }; - } + }, + defaultSelectorOptions ); const ParamCFGScale = () => { diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamHeight.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamHeight.tsx index 9501c8b475..6939ede424 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamHeight.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamHeight.tsx @@ -1,5 +1,6 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAISlider, { IAIFullSliderProps } from 'common/components/IAISlider'; import { generationSelector } from 'features/parameters/store/generationSelectors'; import { setHeight } from 'features/parameters/store/generationSlice'; @@ -25,7 +26,8 @@ const selector = createSelector( inputMax, step, }; - } + }, + defaultSelectorOptions ); type ParamHeightProps = Omit< diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamIterations.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamIterations.tsx index a8cdabc8c9..1e203a1e45 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamIterations.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamIterations.tsx @@ -1,37 +1,38 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAINumberInput from 'common/components/IAINumberInput'; import IAISlider from 'common/components/IAISlider'; -import { generationSelector } from 'features/parameters/store/generationSelectors'; import { setIterations } from 'features/parameters/store/generationSlice'; -import { configSelector } from 'features/system/store/configSelectors'; -import { hotkeysSelector } from 'features/ui/store/hotkeysSlice'; -import { uiSelector } from 'features/ui/store/uiSelectors'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -const selector = createSelector([stateSelector], (state) => { - const { initial, min, sliderMax, inputMax, fineStep, coarseStep } = - state.config.sd.iterations; - const { iterations } = state.generation; - const { shouldUseSliders } = state.ui; - const isDisabled = - state.dynamicPrompts.isEnabled && state.dynamicPrompts.combinatorial; +const selector = createSelector( + [stateSelector], + (state) => { + const { initial, min, sliderMax, inputMax, fineStep, coarseStep } = + state.config.sd.iterations; + const { iterations } = state.generation; + const { shouldUseSliders } = state.ui; + const isDisabled = + state.dynamicPrompts.isEnabled && state.dynamicPrompts.combinatorial; - const step = state.hotkeys.shift ? fineStep : coarseStep; + const step = state.hotkeys.shift ? fineStep : coarseStep; - return { - iterations, - initial, - min, - sliderMax, - inputMax, - step, - shouldUseSliders, - isDisabled, - }; -}); + return { + iterations, + initial, + min, + sliderMax, + inputMax, + step, + shouldUseSliders, + isDisabled, + }; + }, + defaultSelectorOptions +); const ParamIterations = () => { const { diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSchedulerAndModel.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamModelandVAE.tsx similarity index 60% rename from invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSchedulerAndModel.tsx rename to invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamModelandVAE.tsx index 5092893eed..1c704a86ef 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSchedulerAndModel.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamModelandVAE.tsx @@ -1,19 +1,19 @@ import { Box, Flex } from '@chakra-ui/react'; import ModelSelect from 'features/system/components/ModelSelect'; +import VAESelect from 'features/system/components/VAESelect'; import { memo } from 'react'; -import ParamScheduler from './ParamScheduler'; -const ParamSchedulerAndModel = () => { +const ParamModelandVAE = () => { return ( - - - + + + ); }; -export default memo(ParamSchedulerAndModel); +export default memo(ParamModelandVAE); diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamNegativeConditioning.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamNegativeConditioning.tsx index 589b751d6b..3e5320ad47 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamNegativeConditioning.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamNegativeConditioning.tsx @@ -1,29 +1,107 @@ -import { FormControl } from '@chakra-ui/react'; +import { Box, FormControl, useDisclosure } from '@chakra-ui/react'; import type { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAITextarea from 'common/components/IAITextarea'; +import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton'; +import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover'; import { setNegativePrompt } from 'features/parameters/store/generationSlice'; +import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react'; +import { flushSync } from 'react-dom'; import { useTranslation } from 'react-i18next'; const ParamNegativeConditioning = () => { const negativePrompt = useAppSelector( (state: RootState) => state.generation.negativePrompt ); - + const promptRef = useRef(null); + const { isOpen, onClose, onOpen } = useDisclosure(); const dispatch = useAppDispatch(); const { t } = useTranslation(); + const handleChangePrompt = useCallback( + (e: ChangeEvent) => { + dispatch(setNegativePrompt(e.target.value)); + }, + [dispatch] + ); + const handleKeyDown = useCallback( + (e: KeyboardEvent) => { + if (e.key === '<') { + onOpen(); + } + }, + [onOpen] + ); + + const handleSelectEmbedding = useCallback( + (v: string) => { + if (!promptRef.current) { + return; + } + + // this is where we insert the TI trigger + const caret = promptRef.current.selectionStart; + + if (caret === undefined) { + return; + } + + let newPrompt = negativePrompt.slice(0, caret); + + if (newPrompt[newPrompt.length - 1] !== '<') { + newPrompt += '<'; + } + + newPrompt += `${v}>`; + + // we insert the cursor after the `>` + const finalCaretPos = newPrompt.length; + + newPrompt += negativePrompt.slice(caret); + + // must flush dom updates else selection gets reset + flushSync(() => { + dispatch(setNegativePrompt(newPrompt)); + }); + + // set the caret position to just after the TI trigger promptRef.current.selectionStart = finalCaretPos; + promptRef.current.selectionEnd = finalCaretPos; + onClose(); + }, + [dispatch, onClose, negativePrompt] + ); + return ( - dispatch(setNegativePrompt(e.target.value))} - placeholder={t('parameters.negativePromptPlaceholder')} - fontSize="sm" - minH={16} - /> + + + + {!isOpen && ( + + + + )} ); }; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamPositiveConditioning.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamPositiveConditioning.tsx index f42942a84b..cbff29e89c 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamPositiveConditioning.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamPositiveConditioning.tsx @@ -1,4 +1,4 @@ -import { Box, FormControl } from '@chakra-ui/react'; +import { Box, FormControl, useDisclosure } from '@chakra-ui/react'; import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react'; @@ -11,12 +11,15 @@ import { } from 'features/parameters/store/generationSlice'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; -import { isEqual } from 'lodash-es'; -import { useHotkeys } from 'react-hotkeys-hook'; -import { useTranslation } from 'react-i18next'; import { userInvoked } from 'app/store/actions'; import IAITextarea from 'common/components/IAITextarea'; import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; +import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton'; +import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover'; +import { isEqual } from 'lodash-es'; +import { flushSync } from 'react-dom'; +import { useHotkeys } from 'react-hotkeys-hook'; +import { useTranslation } from 'react-i18next'; const promptInputSelector = createSelector( [(state: RootState) => state.generation, activeTabNameSelector], @@ -40,14 +43,15 @@ const ParamPositiveConditioning = () => { const dispatch = useAppDispatch(); const { prompt, activeTabName } = useAppSelector(promptInputSelector); const isReady = useIsReadyToInvoke(); - const promptRef = useRef(null); - + const { isOpen, onClose, onOpen } = useDisclosure(); const { t } = useTranslation(); - - const handleChangePrompt = (e: ChangeEvent) => { - dispatch(setPositivePrompt(e.target.value)); - }; + const handleChangePrompt = useCallback( + (e: ChangeEvent) => { + dispatch(setPositivePrompt(e.target.value)); + }, + [dispatch] + ); useHotkeys( 'alt+a', @@ -57,6 +61,45 @@ const ParamPositiveConditioning = () => { [] ); + const handleSelectEmbedding = useCallback( + (v: string) => { + if (!promptRef.current) { + return; + } + + // this is where we insert the TI trigger + const caret = promptRef.current.selectionStart; + + if (caret === undefined) { + return; + } + + let newPrompt = prompt.slice(0, caret); + + if (newPrompt[newPrompt.length - 1] !== '<') { + newPrompt += '<'; + } + + newPrompt += `${v}>`; + + // we insert the cursor after the `>` + const finalCaretPos = newPrompt.length; + + newPrompt += prompt.slice(caret); + + // must flush dom updates else selection gets reset + flushSync(() => { + dispatch(setPositivePrompt(newPrompt)); + }); + + // set the caret position to just after the TI trigger + promptRef.current.selectionStart = finalCaretPos; + promptRef.current.selectionEnd = finalCaretPos; + onClose(); + }, + [dispatch, onClose, prompt] + ); + const handleKeyDown = useCallback( (e: KeyboardEvent) => { if (e.key === 'Enter' && e.shiftKey === false && isReady) { @@ -64,25 +107,50 @@ const ParamPositiveConditioning = () => { dispatch(clampSymmetrySteps()); dispatch(userInvoked(activeTabName)); } + if (e.key === '<') { + onOpen(); + } }, - [dispatch, activeTabName, isReady] + [isReady, dispatch, activeTabName, onOpen] ); + // const handleSelect = (e: MouseEvent) => { + // const target = e.target as HTMLTextAreaElement; + // setCaret({ start: target.selectionStart, end: target.selectionEnd }); + // }; + return ( - + + + + {!isOpen && ( + + + + )} ); }; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSteps.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSteps.tsx index f43cdd425b..d939113c7c 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSteps.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSteps.tsx @@ -1,5 +1,6 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAINumberInput from 'common/components/IAINumberInput'; import IAISlider from 'common/components/IAISlider'; @@ -33,7 +34,8 @@ const selector = createSelector( step, shouldUseSliders, }; - } + }, + defaultSelectorOptions ); const ParamSteps = () => { diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamWidth.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamWidth.tsx index b7d63038d1..b4121184b5 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamWidth.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamWidth.tsx @@ -1,7 +1,7 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import IAISlider from 'common/components/IAISlider'; -import { IAIFullSliderProps } from 'common/components/IAISlider'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAISlider, { IAIFullSliderProps } from 'common/components/IAISlider'; import { generationSelector } from 'features/parameters/store/generationSelectors'; import { setWidth } from 'features/parameters/store/generationSlice'; import { configSelector } from 'features/system/store/configSelectors'; @@ -26,7 +26,8 @@ const selector = createSelector( inputMax, step, }; - } + }, + defaultSelectorOptions ); type ParamWidthProps = Omit; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresCollapse.tsx index b4b077ad6c..fa8606d610 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresCollapse.tsx @@ -1,37 +1,39 @@ import { Flex } from '@chakra-ui/react'; -import { useTranslation } from 'react-i18next'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { RootState } from 'app/store/store'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAICollapse from 'common/components/IAICollapse'; -import { memo } from 'react'; -import { ParamHiresStrength } from './ParamHiresStrength'; -import { setHiresFix } from 'features/parameters/store/postprocessingSlice'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { ParamHiresStrength } from './ParamHiresStrength'; +import { ParamHiresToggle } from './ParamHiresToggle'; + +const selector = createSelector( + stateSelector, + (state) => { + const activeLabel = state.postprocessing.hiresFix ? 'Enabled' : undefined; + + return { activeLabel }; + }, + defaultSelectorOptions +); const ParamHiresCollapse = () => { const { t } = useTranslation(); - const hiresFix = useAppSelector( - (state: RootState) => state.postprocessing.hiresFix - ); + const { activeLabel } = useAppSelector(selector); const isHiresEnabled = useFeatureStatus('hires').isFeatureEnabled; - const dispatch = useAppDispatch(); - - const handleToggle = () => dispatch(setHiresFix(!hiresFix)); - if (!isHiresEnabled) { return null; } return ( - + + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresToggle.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresToggle.tsx index 0fc600e9e8..f8e6f22aa4 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresToggle.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresToggle.tsx @@ -23,7 +23,6 @@ export const ParamHiresToggle = () => { return ( diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImage.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImage.tsx new file mode 100644 index 0000000000..7951df31a7 --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImage.tsx @@ -0,0 +1,76 @@ +import { Flex, Icon, Text } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { useAppSelector } from 'app/store/storeHooks'; +import { useMemo } from 'react'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIDndImage from 'common/components/IAIDndImage'; +import { useGetImageDTOQuery } from 'services/api/endpoints/images'; +import { skipToken } from '@reduxjs/toolkit/dist/query'; +import { FaImage } from 'react-icons/fa'; +import { stateSelector } from 'app/store/store'; +import { + TypesafeDraggableData, + TypesafeDroppableData, +} from 'app/components/ImageDnd/typesafeDnd'; +import { IAINoContentFallback } from 'common/components/IAIImageFallback'; + +const selector = createSelector( + [stateSelector], + (state) => { + const { initialImage } = state.generation; + const { asInitialImage: useBatchAsInitialImage, imageNames } = state.batch; + return { + initialImage, + useBatchAsInitialImage, + isResetButtonDisabled: useBatchAsInitialImage + ? imageNames.length === 0 + : !initialImage, + }; + }, + defaultSelectorOptions +); + +const InitialImage = () => { + const { initialImage } = useAppSelector(selector); + + const { + currentData: imageDTO, + isLoading, + isError, + isSuccess, + } = useGetImageDTOQuery(initialImage?.imageName ?? skipToken); + + const draggableData = useMemo(() => { + if (imageDTO) { + return { + id: 'initial-image', + payloadType: 'IMAGE_DTO', + payload: { imageDTO }, + }; + } + }, [imageDTO]); + + const droppableData = useMemo( + () => ({ + id: 'initial-image', + actionType: 'SET_INITIAL_IMAGE', + }), + [] + ); + + return ( + + } + /> + ); +}; + +export default InitialImage; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImageDisplay.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImageDisplay.tsx index 19eb45a0a9..c08f714488 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImageDisplay.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImageDisplay.tsx @@ -1,34 +1,154 @@ -import { Flex } from '@chakra-ui/react'; -import InitialImagePreview from './InitialImagePreview'; +import { Flex, Spacer, Text } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { clearInitialImage } from 'features/parameters/store/generationSlice'; +import { useCallback, useMemo } from 'react'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import { useGetImageDTOQuery } from 'services/api/endpoints/images'; +import { skipToken } from '@reduxjs/toolkit/dist/query'; +import IAIIconButton from 'common/components/IAIIconButton'; +import { FaLayerGroup, FaUndo, FaUpload } from 'react-icons/fa'; +import useImageUploader from 'common/hooks/useImageUploader'; +import { useImageUploadButton } from 'common/hooks/useImageUploadButton'; +import IAIButton from 'common/components/IAIButton'; +import { stateSelector } from 'app/store/store'; +import { + asInitialImageToggled, + batchReset, +} from 'features/batch/store/batchSlice'; +import BatchImageContainer from 'features/batch/components/BatchImageContainer'; +import { PostUploadAction } from 'services/api/thunks/image'; +import InitialImage from './InitialImage'; + +const selector = createSelector( + [stateSelector], + (state) => { + const { initialImage } = state.generation; + const { asInitialImage: useBatchAsInitialImage, imageNames } = state.batch; + return { + initialImage, + useBatchAsInitialImage, + isResetButtonDisabled: useBatchAsInitialImage + ? imageNames.length === 0 + : !initialImage, + }; + }, + defaultSelectorOptions +); const InitialImageDisplay = () => { + const { initialImage, useBatchAsInitialImage, isResetButtonDisabled } = + useAppSelector(selector); + const dispatch = useAppDispatch(); + const { openUploader } = useImageUploader(); + + const { + currentData: imageDTO, + isLoading, + isError, + isSuccess, + } = useGetImageDTOQuery(initialImage?.imageName ?? skipToken); + + const postUploadAction = useMemo( + () => + useBatchAsInitialImage + ? { type: 'ADD_TO_BATCH' } + : { type: 'SET_INITIAL_IMAGE' }, + [useBatchAsInitialImage] + ); + + const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({ + postUploadAction, + }); + + const handleReset = useCallback(() => { + if (useBatchAsInitialImage) { + dispatch(batchReset()); + } else { + dispatch(clearInitialImage()); + } + }, [dispatch, useBatchAsInitialImage]); + + const handleUpload = useCallback(() => { + openUploader(); + }, [openUploader]); + + const handleClickUseBatch = useCallback(() => { + dispatch(asInitialImageToggled()); + }, [dispatch]); + return ( - + + Initial Image + + + {/* } + isChecked={useBatchAsInitialImage} + onClick={handleClickUseBatch} + > + {useBatchAsInitialImage ? 'Batch' : 'Single'} + */} + } + onClick={handleUpload} + {...getUploadButtonProps()} + /> + } + onClick={handleReset} + isDisabled={isResetButtonDisabled} + /> + + {/* {useBatchAsInitialImage ? : } */} + ); }; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx deleted file mode 100644 index 2a05eee9b4..0000000000 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx +++ /dev/null @@ -1,126 +0,0 @@ -import { Flex, Spacer, Text } from '@chakra-ui/react'; -import { createSelector } from '@reduxjs/toolkit'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { - clearInitialImage, - initialImageChanged, -} from 'features/parameters/store/generationSlice'; -import { useCallback } from 'react'; -import { generationSelector } from 'features/parameters/store/generationSelectors'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import IAIDndImage from 'common/components/IAIDndImage'; -import { ImageDTO } from 'services/api/types'; -import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback'; -import { useGetImageDTOQuery } from 'services/api/endpoints/images'; -import { skipToken } from '@reduxjs/toolkit/dist/query'; -import IAIIconButton from 'common/components/IAIIconButton'; -import { FaUndo, FaUpload } from 'react-icons/fa'; -import useImageUploader from 'common/hooks/useImageUploader'; -import { useImageUploadButton } from 'common/hooks/useImageUploadButton'; - -const selector = createSelector( - [generationSelector], - (generation) => { - const { initialImage } = generation; - return { - initialImage, - }; - }, - defaultSelectorOptions -); - -const InitialImagePreview = () => { - const { initialImage } = useAppSelector(selector); - const dispatch = useAppDispatch(); - const { openUploader } = useImageUploader(); - - const { - currentData: image, - isLoading, - isError, - isSuccess, - } = useGetImageDTOQuery(initialImage?.imageName ?? skipToken); - - const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({ - postUploadAction: { type: 'SET_INITIAL_IMAGE' }, - }); - - const handleDrop = useCallback( - (droppedImage: ImageDTO) => { - if (droppedImage.image_name === initialImage?.imageName) { - return; - } - dispatch(initialImageChanged(droppedImage)); - }, - [dispatch, initialImage] - ); - - const handleReset = useCallback(() => { - dispatch(clearInitialImage()); - }, [dispatch]); - - const handleUpload = useCallback(() => { - openUploader(); - }, [openUploader]); - - return ( - - - - Initial Image - - - } - onClick={handleUpload} - {...getUploadButtonProps()} - /> - } - onClick={handleReset} - isDisabled={!initialImage} - /> - - } - isUploadDisabled={true} - fitContainer - /> - - - ); -}; - -export default InitialImagePreview; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseCollapse.tsx index adb76d8da0..4dea1dad4f 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseCollapse.tsx @@ -1,27 +1,33 @@ -import { useTranslation } from 'react-i18next'; import { Flex } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAICollapse from 'common/components/IAICollapse'; -import ParamPerlinNoise from './ParamPerlinNoise'; -import ParamNoiseThreshold from './ParamNoiseThreshold'; -import { RootState } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { setShouldUseNoiseSettings } from 'features/parameters/store/generationSlice'; -import { memo } from 'react'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; +import ParamNoiseThreshold from './ParamNoiseThreshold'; +import { ParamNoiseToggle } from './ParamNoiseToggle'; +import ParamPerlinNoise from './ParamPerlinNoise'; + +const selector = createSelector( + stateSelector, + (state) => { + const { shouldUseNoiseSettings } = state.generation; + return { + activeLabel: shouldUseNoiseSettings ? 'Enabled' : undefined, + }; + }, + defaultSelectorOptions +); const ParamNoiseCollapse = () => { const { t } = useTranslation(); const isNoiseEnabled = useFeatureStatus('noise').isFeatureEnabled; - const shouldUseNoiseSettings = useAppSelector( - (state: RootState) => state.generation.shouldUseNoiseSettings - ); - - const dispatch = useAppDispatch(); - - const handleToggle = () => - dispatch(setShouldUseNoiseSettings(!shouldUseNoiseSettings)); + const { activeLabel } = useAppSelector(selector); if (!isNoiseEnabled) { return null; @@ -30,11 +36,10 @@ const ParamNoiseCollapse = () => { return ( + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseThreshold.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseThreshold.tsx index e339734992..3abb7532b4 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseThreshold.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseThreshold.tsx @@ -1,18 +1,31 @@ -import { RootState } from 'app/store/store'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAISlider from 'common/components/IAISlider'; import { setThreshold } from 'features/parameters/store/generationSlice'; import { useTranslation } from 'react-i18next'; +const selector = createSelector( + stateSelector, + (state) => { + const { shouldUseNoiseSettings, threshold } = state.generation; + return { + isDisabled: !shouldUseNoiseSettings, + threshold, + }; + }, + defaultSelectorOptions +); + export default function ParamNoiseThreshold() { const dispatch = useAppDispatch(); - const threshold = useAppSelector( - (state: RootState) => state.generation.threshold - ); + const { threshold, isDisabled } = useAppSelector(selector); const { t } = useTranslation(); return ( { + const dispatch = useAppDispatch(); + + const shouldUseNoiseSettings = useAppSelector( + (state: RootState) => state.generation.shouldUseNoiseSettings + ); + + const { t } = useTranslation(); + + const handleChange = (e: ChangeEvent) => + dispatch(setShouldUseNoiseSettings(e.target.checked)); + + return ( + + ); +}; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamPerlinNoise.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamPerlinNoise.tsx index ad710eae54..afd676223c 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamPerlinNoise.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamPerlinNoise.tsx @@ -1,16 +1,31 @@ -import { RootState } from 'app/store/store'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAISlider from 'common/components/IAISlider'; import { setPerlin } from 'features/parameters/store/generationSlice'; import { useTranslation } from 'react-i18next'; +const selector = createSelector( + stateSelector, + (state) => { + const { shouldUseNoiseSettings, perlin } = state.generation; + return { + isDisabled: !shouldUseNoiseSettings, + perlin, + }; + }, + defaultSelectorOptions +); + export default function ParamPerlinNoise() { const dispatch = useAppDispatch(); - const perlin = useAppSelector((state: RootState) => state.generation.perlin); + const { perlin, isDisabled } = useAppSelector(selector); const { t } = useTranslation(); return ( { + if (seamlessXAxis && seamlessYAxis) { + return 'X & Y'; + } + + if (seamlessXAxis) { + return 'X'; + } + + if (seamlessYAxis) { + return 'Y'; + } +}; const selector = createSelector( generationSelector, (generation) => { - const { shouldUseSeamless, seamlessXAxis, seamlessYAxis } = generation; + const { seamlessXAxis, seamlessYAxis } = generation; - return { shouldUseSeamless, seamlessXAxis, seamlessYAxis }; + const activeLabel = getActiveLabel(seamlessXAxis, seamlessYAxis); + return { activeLabel }; }, defaultSelectorOptions ); const ParamSeamlessCollapse = () => { const { t } = useTranslation(); - const { shouldUseSeamless } = useAppSelector(selector); + const { activeLabel } = useAppSelector(selector); const isSeamlessEnabled = useFeatureStatus('seamless').isFeatureEnabled; - const dispatch = useAppDispatch(); - - const handleToggle = () => dispatch(setSeamless(!shouldUseSeamless)); - if (!isSeamlessEnabled) { return null; } @@ -38,9 +48,7 @@ const ParamSeamlessCollapse = () => { return ( diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Seed/ParamSeedRandomize.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Seed/ParamSeedRandomize.tsx index 6b1dd46780..f30d9215e8 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Seed/ParamSeedRandomize.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Seed/ParamSeedRandomize.tsx @@ -1,10 +1,8 @@ import { ChangeEvent, memo } from 'react'; - import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { setShouldRandomizeSeed } from 'features/parameters/store/generationSlice'; import { useTranslation } from 'react-i18next'; -import { FormControl, FormLabel, Switch, Tooltip } from '@chakra-ui/react'; import IAISwitch from 'common/components/IAISwitch'; const ParamSeedRandomize = () => { @@ -25,32 +23,6 @@ const ParamSeedRandomize = () => { onChange={handleChangeShouldRandomizeSeed} /> ); - - return ( - - - {t('parameters.randomizeSeed')} - - - - ); }; export default memo(ParamSeedRandomize); diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Seed/ParamSeedShuffle.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Seed/ParamSeedShuffle.tsx index 6442e34268..e71e2c36c0 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Seed/ParamSeedShuffle.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Seed/ParamSeedShuffle.tsx @@ -1,8 +1,6 @@ -import { Box } from '@chakra-ui/react'; import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from 'app/constants'; import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import IAIButton from 'common/components/IAIButton'; import IAIIconButton from 'common/components/IAIIconButton'; import randomInt from 'common/util/randomInt'; import { setSeed } from 'features/parameters/store/generationSlice'; @@ -29,16 +27,4 @@ export default function ParamSeedShuffle() { icon={} /> ); - - return ( - - {t('parameters.shuffle')} - - ); } diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse.tsx index 59bdb39be1..f2ddd19768 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse.tsx @@ -1,39 +1,39 @@ -import { memo } from 'react'; import { Flex } from '@chakra-ui/react'; +import { memo } from 'react'; import ParamSymmetryHorizontal from './ParamSymmetryHorizontal'; import ParamSymmetryVertical from './ParamSymmetryVertical'; -import { useTranslation } from 'react-i18next'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAICollapse from 'common/components/IAICollapse'; -import { RootState } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { setShouldUseSymmetry } from 'features/parameters/store/generationSlice'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; +import { useTranslation } from 'react-i18next'; +import ParamSymmetryToggle from './ParamSymmetryToggle'; + +const selector = createSelector( + stateSelector, + (state) => ({ + activeLabel: state.generation.shouldUseSymmetry ? 'Enabled' : undefined, + }), + defaultSelectorOptions +); const ParamSymmetryCollapse = () => { const { t } = useTranslation(); - const shouldUseSymmetry = useAppSelector( - (state: RootState) => state.generation.shouldUseSymmetry - ); + const { activeLabel } = useAppSelector(selector); const isSymmetryEnabled = useFeatureStatus('symmetry').isFeatureEnabled; - const dispatch = useAppDispatch(); - - const handleToggle = () => dispatch(setShouldUseSymmetry(!shouldUseSymmetry)); - if (!isSymmetryEnabled) { return null; } return ( - + + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryToggle.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryToggle.tsx index 7cc17c045e..59386ff526 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryToggle.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryToggle.tsx @@ -12,6 +12,7 @@ export default function ParamSymmetryToggle() { return ( dispatch(setShouldUseSymmetry(e.target.checked))} /> diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationCollapse.tsx index 1564bd64e5..3cdfc3a06b 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationCollapse.tsx @@ -1,39 +1,42 @@ -import ParamVariationWeights from './ParamVariationWeights'; -import ParamVariationAmount from './ParamVariationAmount'; -import { useTranslation } from 'react-i18next'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { RootState } from 'app/store/store'; -import { setShouldGenerateVariations } from 'features/parameters/store/generationSlice'; import { Flex } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAICollapse from 'common/components/IAICollapse'; -import { memo } from 'react'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; +import ParamVariationAmount from './ParamVariationAmount'; +import { ParamVariationToggle } from './ParamVariationToggle'; +import ParamVariationWeights from './ParamVariationWeights'; + +const selector = createSelector( + stateSelector, + (state) => { + const activeLabel = state.generation.shouldGenerateVariations + ? 'Enabled' + : undefined; + + return { activeLabel }; + }, + defaultSelectorOptions +); const ParamVariationCollapse = () => { const { t } = useTranslation(); - const shouldGenerateVariations = useAppSelector( - (state: RootState) => state.generation.shouldGenerateVariations - ); + const { activeLabel } = useAppSelector(selector); const isVariationEnabled = useFeatureStatus('variation').isFeatureEnabled; - const dispatch = useAppDispatch(); - - const handleToggle = () => - dispatch(setShouldGenerateVariations(!shouldGenerateVariations)); - if (!isVariationEnabled) { return null; } return ( - + + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationToggle.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationToggle.tsx new file mode 100644 index 0000000000..1c05468de0 --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationToggle.tsx @@ -0,0 +1,27 @@ +import type { RootState } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import IAISwitch from 'common/components/IAISwitch'; +import { setShouldGenerateVariations } from 'features/parameters/store/generationSlice'; +import { ChangeEvent } from 'react'; +import { useTranslation } from 'react-i18next'; + +export const ParamVariationToggle = () => { + const dispatch = useAppDispatch(); + + const shouldGenerateVariations = useAppSelector( + (state: RootState) => state.generation.shouldGenerateVariations + ); + + const { t } = useTranslation(); + + const handleChange = (e: ChangeEvent) => + dispatch(setShouldGenerateVariations(e.target.checked)); + + return ( + + ); +}; diff --git a/invokeai/frontend/web/src/features/parameters/components/ProcessButtons/InvokeButton.tsx b/invokeai/frontend/web/src/features/parameters/components/ProcessButtons/InvokeButton.tsx index 6f82562e48..e2338e2575 100644 --- a/invokeai/frontend/web/src/features/parameters/components/ProcessButtons/InvokeButton.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/ProcessButtons/InvokeButton.tsx @@ -1,4 +1,4 @@ -import { Box } from '@chakra-ui/react'; +import { Box, ChakraProps } from '@chakra-ui/react'; import { userInvoked } from 'app/store/actions'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIButton, { IAIButtonProps } from 'common/components/IAIButton'; @@ -14,6 +14,16 @@ import { useHotkeys } from 'react-hotkeys-hook'; import { useTranslation } from 'react-i18next'; import { FaPlay } from 'react-icons/fa'; +const IN_PROGRESS_STYLES: ChakraProps['sx'] = { + _disabled: { + bg: 'none', + cursor: 'not-allowed', + _hover: { + bg: 'none', + }, + }, +}; + interface InvokeButton extends Omit { iconButton?: boolean; @@ -24,6 +34,7 @@ export default function InvokeButton(props: InvokeButton) { const dispatch = useAppDispatch(); const isReady = useIsReadyToInvoke(); const activeTabName = useAppSelector(activeTabNameSelector); + const isProcessing = useAppSelector((state) => state.system.isProcessing); const handleInvoke = useCallback(() => { dispatch(clampSymmetrySteps()); @@ -48,6 +59,7 @@ export default function InvokeButton(props: InvokeButton) { {!isReady && ( @@ -68,13 +80,16 @@ export default function InvokeButton(props: InvokeButton) { icon={} isDisabled={!isReady} onClick={handleInvoke} - flexGrow={1} - w="100%" tooltip={t('parameters.invoke')} tooltipProps={{ placement: 'top' }} colorScheme="accent" id="invoke-button" {...rest} + sx={{ + w: 'full', + flexGrow: 1, + ...(isProcessing ? IN_PROGRESS_STYLES : {}), + }} /> ) : ( Invoke diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index c8e65314da..960a41bb45 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -14,6 +14,7 @@ import { SeedParam, StepsParam, StrengthParam, + VAEParam, WidthParam, } from './parameterZodSchemas'; @@ -47,7 +48,7 @@ export interface GenerationState { horizontalSymmetrySteps: number; verticalSymmetrySteps: number; model: ModelParam; - shouldUseSeamless: boolean; + vae: VAEParam; seamlessXAxis: boolean; seamlessYAxis: boolean; } @@ -81,9 +82,9 @@ export const initialGenerationState: GenerationState = { horizontalSymmetrySteps: 0, verticalSymmetrySteps: 0, model: '', - shouldUseSeamless: false, - seamlessXAxis: true, - seamlessYAxis: true, + vae: '', + seamlessXAxis: false, + seamlessYAxis: false, }; const initialState: GenerationState = initialGenerationState; @@ -141,9 +142,6 @@ export const generationSlice = createSlice({ setImg2imgStrength: (state, action: PayloadAction) => { state.img2imgStrength = action.payload; }, - setSeamless: (state, action: PayloadAction) => { - state.shouldUseSeamless = action.payload; - }, setSeamlessXAxis: (state, action: PayloadAction) => { state.seamlessXAxis = action.payload; }, @@ -216,6 +214,9 @@ export const generationSlice = createSlice({ modelSelected: (state, action: PayloadAction) => { state.model = action.payload; }, + vaeSelected: (state, action: PayloadAction) => { + state.vae = action.payload; + }, }, extraReducers: (builder) => { builder.addCase(configChanged, (state, action) => { @@ -260,8 +261,8 @@ export const { setVerticalSymmetrySteps, initialImageChanged, modelSelected, + vaeSelected, setShouldUseNoiseSettings, - setSeamless, setSeamlessXAxis, setSeamlessYAxis, } = generationSlice.actions; diff --git a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts index 48eb309e7d..12d77beeb9 100644 --- a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts @@ -135,6 +135,15 @@ export const zModel = z.string(); * Type alias for model parameter, inferred from its zod schema */ export type ModelParam = z.infer; +/** + * Zod schema for VAE parameter + * TODO: Make this a dynamically generated enum? + */ +export const zVAE = z.string(); +/** + * Type alias for model parameter, inferred from its zod schema + */ +export type VAEParam = z.infer; /** * Validates/type-guards a value as a model parameter */ diff --git a/invokeai/frontend/web/src/features/system/components/ModelManager/AddModel.tsx b/invokeai/frontend/web/src/features/system/components/ModelManager/AddModel.tsx deleted file mode 100644 index bd0d0e5d3a..0000000000 --- a/invokeai/frontend/web/src/features/system/components/ModelManager/AddModel.tsx +++ /dev/null @@ -1,125 +0,0 @@ -import { - Button, - Flex, - Modal, - ModalBody, - ModalCloseButton, - ModalContent, - ModalFooter, - ModalHeader, - ModalOverlay, - Text, - useDisclosure, -} from '@chakra-ui/react'; - -import IAIButton from 'common/components/IAIButton'; - -import { FaArrowLeft, FaPlus } from 'react-icons/fa'; - -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { useTranslation } from 'react-i18next'; - -import type { RootState } from 'app/store/store'; -import { setAddNewModelUIOption } from 'features/ui/store/uiSlice'; -import AddCheckpointModel from './AddCheckpointModel'; -import AddDiffusersModel from './AddDiffusersModel'; -import IAIIconButton from 'common/components/IAIIconButton'; - -function AddModelBox({ - text, - onClick, -}: { - text: string; - onClick?: () => void; -}) { - return ( - - {text} - - ); -} - -export default function AddModel() { - const { isOpen, onOpen, onClose } = useDisclosure(); - - const addNewModelUIOption = useAppSelector( - (state: RootState) => state.ui.addNewModelUIOption - ); - - const dispatch = useAppDispatch(); - - const { t } = useTranslation(); - - const addModelModalClose = () => { - onClose(); - dispatch(setAddNewModelUIOption(null)); - }; - - return ( - <> - - - - {t('modelManager.addNew')} - - - - - - - {t('modelManager.addNewModel')} - {addNewModelUIOption !== null && ( - dispatch(setAddNewModelUIOption(null))} - position="absolute" - variant="ghost" - zIndex={1} - size="sm" - insetInlineEnd={12} - top={2} - icon={} - /> - )} - - - {addNewModelUIOption == null && ( - - dispatch(setAddNewModelUIOption('ckpt'))} - /> - dispatch(setAddNewModelUIOption('diffusers'))} - /> - - )} - {addNewModelUIOption == 'ckpt' && } - {addNewModelUIOption == 'diffusers' && } - - - - - - ); -} diff --git a/invokeai/frontend/web/src/features/system/components/ModelManager/CheckpointModelEdit.tsx b/invokeai/frontend/web/src/features/system/components/ModelManager/CheckpointModelEdit.tsx deleted file mode 100644 index b860a0848c..0000000000 --- a/invokeai/frontend/web/src/features/system/components/ModelManager/CheckpointModelEdit.tsx +++ /dev/null @@ -1,339 +0,0 @@ -import { createSelector } from '@reduxjs/toolkit'; - -import IAIButton from 'common/components/IAIButton'; -import IAIInput from 'common/components/IAIInput'; -import IAINumberInput from 'common/components/IAINumberInput'; -import { useEffect, useState } from 'react'; - -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { systemSelector } from 'features/system/store/systemSelectors'; - -import { - Flex, - FormControl, - FormLabel, - HStack, - Text, - VStack, -} from '@chakra-ui/react'; - -// import { addNewModel } from 'app/socketio/actions'; -import { Field, Formik } from 'formik'; -import { useTranslation } from 'react-i18next'; - -import type { InvokeModelConfigProps } from 'app/types/invokeai'; -import type { RootState } from 'app/store/store'; -import type { FieldInputProps, FormikProps } from 'formik'; -import { isEqual, pickBy } from 'lodash-es'; -import ModelConvert from './ModelConvert'; -import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText'; -import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage'; -import IAIForm from 'common/components/IAIForm'; - -const selector = createSelector( - [systemSelector], - (system) => { - const { openModel, model_list } = system; - return { - model_list, - openModel, - }; - }, - { - memoizeOptions: { - resultEqualityCheck: isEqual, - }, - } -); - -const MIN_MODEL_SIZE = 64; -const MAX_MODEL_SIZE = 2048; - -export default function CheckpointModelEdit() { - const { openModel, model_list } = useAppSelector(selector); - const isProcessing = useAppSelector( - (state: RootState) => state.system.isProcessing - ); - - const dispatch = useAppDispatch(); - - const { t } = useTranslation(); - - const [editModelFormValues, setEditModelFormValues] = - useState({ - name: '', - description: '', - config: 'configs/stable-diffusion/v1-inference.yaml', - weights: '', - vae: '', - width: 512, - height: 512, - default: false, - format: 'ckpt', - }); - - useEffect(() => { - if (openModel) { - const retrievedModel = pickBy(model_list, (_val, key) => { - return isEqual(key, openModel); - }); - setEditModelFormValues({ - name: openModel, - description: retrievedModel[openModel]?.description, - config: retrievedModel[openModel]?.config, - weights: retrievedModel[openModel]?.weights, - vae: retrievedModel[openModel]?.vae, - width: retrievedModel[openModel]?.width, - height: retrievedModel[openModel]?.height, - default: retrievedModel[openModel]?.default, - format: 'ckpt', - }); - } - }, [model_list, openModel]); - - const editModelFormSubmitHandler = (values: InvokeModelConfigProps) => { - dispatch( - addNewModel({ - ...values, - width: Number(values.width), - height: Number(values.height), - }) - ); - }; - - return openModel ? ( - - - - {openModel} - - - - - - {({ handleSubmit, errors, touched }) => ( - - - {/* Description */} - - - {t('modelManager.description')} - - - - {!!errors.description && touched.description ? ( - - {errors.description} - - ) : ( - - {t('modelManager.descriptionValidationMsg')} - - )} - - - - {/* Config */} - - - {t('modelManager.config')} - - - - {!!errors.config && touched.config ? ( - {errors.config} - ) : ( - - {t('modelManager.configValidationMsg')} - - )} - - - - {/* Weights */} - - - {t('modelManager.modelLocation')} - - - - {!!errors.weights && touched.weights ? ( - - {errors.weights} - - ) : ( - - {t('modelManager.modelLocationValidationMsg')} - - )} - - - - {/* VAE */} - - - {t('modelManager.vaeLocation')} - - - - {!!errors.vae && touched.vae ? ( - {errors.vae} - ) : ( - - {t('modelManager.vaeLocationValidationMsg')} - - )} - - - - - {/* Width */} - - - {t('modelManager.width')} - - - - {({ - field, - form, - }: { - field: FieldInputProps; - form: FormikProps; - }) => ( - - form.setFieldValue(field.name, Number(value)) - } - /> - )} - - - {!!errors.width && touched.width ? ( - - {errors.width} - - ) : ( - - {t('modelManager.widthValidationMsg')} - - )} - - - - {/* Height */} - - - {t('modelManager.height')} - - - - {({ - field, - form, - }: { - field: FieldInputProps; - form: FormikProps; - }) => ( - - form.setFieldValue(field.name, Number(value)) - } - /> - )} - - - {!!errors.height && touched.height ? ( - - {errors.height} - - ) : ( - - {t('modelManager.heightValidationMsg')} - - )} - - - - - - {t('modelManager.updateModel')} - - - - )} - - - - ) : ( - - Pick A Model To Edit - - ); -} diff --git a/invokeai/frontend/web/src/features/system/components/ModelManager/DiffusersModelEdit.tsx b/invokeai/frontend/web/src/features/system/components/ModelManager/DiffusersModelEdit.tsx deleted file mode 100644 index 81998e4976..0000000000 --- a/invokeai/frontend/web/src/features/system/components/ModelManager/DiffusersModelEdit.tsx +++ /dev/null @@ -1,281 +0,0 @@ -import { createSelector } from '@reduxjs/toolkit'; - -import IAIButton from 'common/components/IAIButton'; -import IAIInput from 'common/components/IAIInput'; -import { useEffect, useState } from 'react'; - -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { systemSelector } from 'features/system/store/systemSelectors'; - -import { Flex, FormControl, FormLabel, Text, VStack } from '@chakra-ui/react'; - -// import { addNewModel } from 'app/socketio/actions'; -import { Field, Formik } from 'formik'; -import { useTranslation } from 'react-i18next'; - -import type { InvokeDiffusersModelConfigProps } from 'app/types/invokeai'; -import type { RootState } from 'app/store/store'; -import { isEqual, pickBy } from 'lodash-es'; -import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText'; -import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage'; -import IAIForm from 'common/components/IAIForm'; - -const selector = createSelector( - [systemSelector], - (system) => { - const { openModel, model_list } = system; - return { - model_list, - openModel, - }; - }, - { - memoizeOptions: { - resultEqualityCheck: isEqual, - }, - } -); - -export default function DiffusersModelEdit() { - const { openModel, model_list } = useAppSelector(selector); - const isProcessing = useAppSelector( - (state: RootState) => state.system.isProcessing - ); - - const dispatch = useAppDispatch(); - - const { t } = useTranslation(); - - const [editModelFormValues, setEditModelFormValues] = - useState({ - name: '', - description: '', - repo_id: '', - path: '', - vae: { repo_id: '', path: '' }, - default: false, - format: 'diffusers', - }); - - useEffect(() => { - if (openModel) { - const retrievedModel = pickBy(model_list, (_val, key) => { - return isEqual(key, openModel); - }); - - setEditModelFormValues({ - name: openModel, - description: retrievedModel[openModel]?.description, - path: - retrievedModel[openModel]?.path && - retrievedModel[openModel]?.path !== 'None' - ? retrievedModel[openModel]?.path - : '', - repo_id: - retrievedModel[openModel]?.repo_id && - retrievedModel[openModel]?.repo_id !== 'None' - ? retrievedModel[openModel]?.repo_id - : '', - vae: { - repo_id: retrievedModel[openModel]?.vae?.repo_id - ? retrievedModel[openModel]?.vae?.repo_id - : '', - path: retrievedModel[openModel]?.vae?.path - ? retrievedModel[openModel]?.vae?.path - : '', - }, - default: retrievedModel[openModel]?.default, - format: 'diffusers', - }); - } - }, [model_list, openModel]); - - const editModelFormSubmitHandler = ( - values: InvokeDiffusersModelConfigProps - ) => { - const diffusersModelToEdit = values; - - if (values.path === '') delete diffusersModelToEdit.path; - if (values.repo_id === '') delete diffusersModelToEdit.repo_id; - if (values.vae.path === '') delete diffusersModelToEdit.vae.path; - if (values.vae.repo_id === '') delete diffusersModelToEdit.vae.repo_id; - - dispatch(addNewModel(values)); - }; - - return openModel ? ( - - - - {openModel} - - - - - {({ handleSubmit, errors, touched }) => ( - - - {/* Description */} - - - {t('modelManager.description')} - - - - {!!errors.description && touched.description ? ( - - {errors.description} - - ) : ( - - {t('modelManager.descriptionValidationMsg')} - - )} - - - - {/* Path */} - - - {t('modelManager.modelLocation')} - - - - {!!errors.path && touched.path ? ( - {errors.path} - ) : ( - - {t('modelManager.modelLocationValidationMsg')} - - )} - - - - {/* Repo ID */} - - - {t('modelManager.repo_id')} - - - - {!!errors.repo_id && touched.repo_id ? ( - - {errors.repo_id} - - ) : ( - - {t('modelManager.repoIDValidationMsg')} - - )} - - - - {/* VAE Path */} - - - {t('modelManager.vaeLocation')} - - - - {!!errors.vae?.path && touched.vae?.path ? ( - - {errors.vae?.path} - - ) : ( - - {t('modelManager.vaeLocationValidationMsg')} - - )} - - - - {/* VAE Repo ID */} - - - {t('modelManager.vaeRepoID')} - - - - {!!errors.vae?.repo_id && touched.vae?.repo_id ? ( - - {errors.vae?.repo_id} - - ) : ( - - {t('modelManager.vaeRepoIDValidationMsg')} - - )} - - - - - {t('modelManager.updateModel')} - - - - )} - - - - ) : ( - - Pick A Model To Edit - - ); -} diff --git a/invokeai/frontend/web/src/features/system/components/ModelManager/MergeModels.tsx b/invokeai/frontend/web/src/features/system/components/ModelManager/MergeModels.tsx deleted file mode 100644 index 219d49d4ee..0000000000 --- a/invokeai/frontend/web/src/features/system/components/ModelManager/MergeModels.tsx +++ /dev/null @@ -1,313 +0,0 @@ -import { - Flex, - Modal, - ModalBody, - ModalCloseButton, - ModalContent, - ModalFooter, - ModalHeader, - ModalOverlay, - Radio, - RadioGroup, - Text, - Tooltip, - useDisclosure, -} from '@chakra-ui/react'; -// import { mergeDiffusersModels } from 'app/socketio/actions'; -import { RootState } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import IAIButton from 'common/components/IAIButton'; -import IAIInput from 'common/components/IAIInput'; -import IAISelect from 'common/components/IAISelect'; -import { diffusersModelsSelector } from 'features/system/store/systemSelectors'; -import { useState } from 'react'; -import { useTranslation } from 'react-i18next'; -import * as InvokeAI from 'app/types/invokeai'; -import IAISlider from 'common/components/IAISlider'; -import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox'; - -export default function MergeModels() { - const dispatch = useAppDispatch(); - - const { isOpen, onOpen, onClose } = useDisclosure(); - - const diffusersModels = useAppSelector(diffusersModelsSelector); - - const { t } = useTranslation(); - - const [modelOne, setModelOne] = useState( - Object.keys(diffusersModels)[0] - ); - const [modelTwo, setModelTwo] = useState( - Object.keys(diffusersModels)[1] - ); - const [modelThree, setModelThree] = useState('none'); - - const [mergedModelName, setMergedModelName] = useState(''); - const [modelMergeAlpha, setModelMergeAlpha] = useState(0.5); - - const [modelMergeInterp, setModelMergeInterp] = useState< - 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference' - >('weighted_sum'); - - const [modelMergeSaveLocType, setModelMergeSaveLocType] = useState< - 'root' | 'custom' - >('root'); - - const [modelMergeCustomSaveLoc, setModelMergeCustomSaveLoc] = - useState(''); - - const [modelMergeForce, setModelMergeForce] = useState(false); - - const modelOneList = Object.keys(diffusersModels).filter( - (model) => model !== modelTwo && model !== modelThree - ); - - const modelTwoList = Object.keys(diffusersModels).filter( - (model) => model !== modelOne && model !== modelThree - ); - - const modelThreeList = [ - { key: t('modelManager.none'), value: 'none' }, - ...Object.keys(diffusersModels) - .filter((model) => model !== modelOne && model !== modelTwo) - .map((model) => ({ key: model, value: model })), - ]; - - const isProcessing = useAppSelector( - (state: RootState) => state.system.isProcessing - ); - - const mergeModelsHandler = () => { - let modelsToMerge: string[] = [modelOne, modelTwo, modelThree]; - modelsToMerge = modelsToMerge.filter((model) => model !== 'none'); - - const mergeModelsInfo: InvokeAI.InvokeModelMergingProps = { - models_to_merge: modelsToMerge, - merged_model_name: - mergedModelName !== '' ? mergedModelName : modelsToMerge.join('-'), - alpha: modelMergeAlpha, - interp: modelMergeInterp, - model_merge_save_path: - modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc, - force: modelMergeForce, - }; - - dispatch(mergeDiffusersModels(mergeModelsInfo)); - }; - - return ( - <> - - - {t('modelManager.mergeModels')} - - - - - - - {t('modelManager.mergeModels')} - - - - - {t('modelManager.modelMergeHeaderHelp1')} - - {t('modelManager.modelMergeHeaderHelp2')} - - - - setModelOne(e.target.value)} - /> - setModelTwo(e.target.value)} - /> - { - if (e.target.value !== 'none') { - setModelThree(e.target.value); - setModelMergeInterp('add_difference'); - } else { - setModelThree('none'); - setModelMergeInterp('weighted_sum'); - } - }} - /> - - - setMergedModelName(e.target.value)} - /> - - - setModelMergeAlpha(v)} - withInput - withReset - handleReset={() => setModelMergeAlpha(0.5)} - withSliderMarks - /> - - {t('modelManager.modelMergeAlphaHelp')} - - - - - - {t('modelManager.interpolationType')} - - setModelMergeInterp(v)} - > - - {modelThree === 'none' ? ( - <> - - - {t('modelManager.weightedSum')} - - - - {t('modelManager.sigmoid')} - - - - {t('modelManager.inverseSigmoid')} - - - - ) : ( - - - - {t('modelManager.addDifference')} - - - - )} - - - - - - - - {t('modelManager.mergedModelSaveLocation')} - - - setModelMergeSaveLocType(v) - } - > - - - - {t('modelManager.invokeAIFolder')} - - - - - {t('modelManager.custom')} - - - - - - {modelMergeSaveLocType === 'custom' && ( - setModelMergeCustomSaveLoc(e.target.value)} - /> - )} - - - setModelMergeForce(e.target.checked)} - fontWeight="500" - /> - - - {t('modelManager.merge')} - - - - - - - - ); -} diff --git a/invokeai/frontend/web/src/features/system/components/ModelManager/ModelManagerModal.tsx b/invokeai/frontend/web/src/features/system/components/ModelManager/ModelManagerModal.tsx deleted file mode 100644 index 440e5ad4db..0000000000 --- a/invokeai/frontend/web/src/features/system/components/ModelManager/ModelManagerModal.tsx +++ /dev/null @@ -1,76 +0,0 @@ -import { - Flex, - Modal, - ModalBody, - ModalCloseButton, - ModalContent, - ModalFooter, - ModalHeader, - ModalOverlay, - useDisclosure, -} from '@chakra-ui/react'; -import { cloneElement } from 'react'; - -import { RootState } from 'app/store/store'; -import { useAppSelector } from 'app/store/storeHooks'; -import { useTranslation } from 'react-i18next'; - -import type { ReactElement } from 'react'; - -import CheckpointModelEdit from './CheckpointModelEdit'; -import DiffusersModelEdit from './DiffusersModelEdit'; -import ModelList from './ModelList'; - -type ModelManagerModalProps = { - children: ReactElement; -}; - -export default function ModelManagerModal({ - children, -}: ModelManagerModalProps) { - const { - isOpen: isModelManagerModalOpen, - onOpen: onModelManagerModalOpen, - onClose: onModelManagerModalClose, - } = useDisclosure(); - - const model_list = useAppSelector( - (state: RootState) => state.system.model_list - ); - - const openModel = useAppSelector( - (state: RootState) => state.system.openModel - ); - - const { t } = useTranslation(); - - return ( - <> - {cloneElement(children, { - onClick: onModelManagerModalOpen, - })} - - - - - {t('modelManager.modelManager')} - - - - {openModel && model_list[openModel]['format'] === 'diffusers' ? ( - - ) : ( - - )} - - - - - - - ); -} diff --git a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx index f9eda624f2..4eeee3e4c6 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx +++ b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx @@ -5,10 +5,10 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { modelSelected } from 'features/parameters/store/generationSlice'; -import { forEach, isString } from 'lodash-es'; import { SelectItem } from '@mantine/core'; import { RootState } from 'app/store/store'; -import { useListModelsQuery } from 'services/api/endpoints/models'; +import { forEach, isString } from 'lodash-es'; +import { useGetMainModelsQuery } from 'services/api/endpoints/models'; export const MODEL_TYPE_MAP = { 'sd-1': 'Stable Diffusion 1.x', @@ -23,18 +23,16 @@ const ModelSelect = () => { (state: RootState) => state.generation.model ); - const { data: pipelineModels } = useListModelsQuery({ - model_type: 'main', - }); + const { data: mainModels, isLoading } = useGetMainModelsQuery(); const data = useMemo(() => { - if (!pipelineModels) { + if (!mainModels) { return []; } const data: SelectItem[] = []; - forEach(pipelineModels.entities, (model, id) => { + forEach(mainModels.entities, (model, id) => { if (!model) { return; } @@ -47,11 +45,11 @@ const ModelSelect = () => { }); return data; - }, [pipelineModels]); + }, [mainModels]); const selectedModel = useMemo( - () => pipelineModels?.entities[selectedModelId], - [pipelineModels?.entities, selectedModelId] + () => mainModels?.entities[selectedModelId], + [mainModels?.entities, selectedModelId] ); const handleChangeModel = useCallback( @@ -65,26 +63,34 @@ const ModelSelect = () => { ); useEffect(() => { - if (selectedModelId && pipelineModels?.ids.includes(selectedModelId)) { + if (selectedModelId && mainModels?.ids.includes(selectedModelId)) { return; } - const firstModel = pipelineModels?.ids[0]; + const firstModel = mainModels?.ids[0]; if (!isString(firstModel)) { return; } handleChangeModel(firstModel); - }, [handleChangeModel, pipelineModels?.ids, selectedModelId]); + }, [handleChangeModel, mainModels?.ids, selectedModelId]); - return ( + return isLoading ? ( + + ) : ( 0 ? 'Select a model' : 'No models detected!'} data={data} + error={data.length === 0} onChange={handleChangeModel} /> ); diff --git a/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx b/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx index 140a8b5978..34bd394214 100644 --- a/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx +++ b/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx @@ -35,6 +35,7 @@ const ProgressBar = () => { aria-label={t('accessibility.invokeProgressBar')} isIndeterminate={isProcessing && !currentStatusHasSteps} height="full" + colorScheme="accent" /> ); }; diff --git a/invokeai/frontend/web/src/features/system/components/SiteHeader.tsx b/invokeai/frontend/web/src/features/system/components/SiteHeader.tsx index 0c94c1d2f9..758f03f19a 100644 --- a/invokeai/frontend/web/src/features/system/components/SiteHeader.tsx +++ b/invokeai/frontend/web/src/features/system/components/SiteHeader.tsx @@ -5,21 +5,18 @@ import StatusIndicator from './StatusIndicator'; import { Link } from '@chakra-ui/react'; import IAIIconButton from 'common/components/IAIIconButton'; import { useTranslation } from 'react-i18next'; -import { FaBug, FaCube, FaDiscord, FaGithub, FaKeyboard } from 'react-icons/fa'; +import { FaBug, FaDiscord, FaGithub, FaKeyboard } from 'react-icons/fa'; import { MdSettings } from 'react-icons/md'; +import { useFeatureStatus } from '../hooks/useFeatureStatus'; +import ColorModeButton from './ColorModeButton'; import HotkeysModal from './HotkeysModal/HotkeysModal'; import InvokeAILogoComponent from './InvokeAILogoComponent'; import LanguagePicker from './LanguagePicker'; -import ModelManagerModal from './ModelManager/ModelManagerModal'; import SettingsModal from './SettingsModal/SettingsModal'; -import { useFeatureStatus } from '../hooks/useFeatureStatus'; -import ColorModeButton from './ColorModeButton'; const SiteHeader = () => { const { t } = useTranslation(); - const isModelManagerEnabled = - useFeatureStatus('modelManager').isFeatureEnabled; const isLocalizationEnabled = useFeatureStatus('localization').isFeatureEnabled; const isBugLinkEnabled = useFeatureStatus('bugLink').isFeatureEnabled; @@ -37,20 +34,6 @@ const SiteHeader = () => { - {isModelManagerEnabled && ( - - } - /> - - )} - { const { t } = useTranslation(); - const isModelManagerEnabled = - useFeatureStatus('modelManager').isFeatureEnabled; const isLocalizationEnabled = useFeatureStatus('localization').isFeatureEnabled; const isBugLinkEnabled = useFeatureStatus('bugLink').isFeatureEnabled; @@ -27,20 +24,6 @@ const SiteHeaderMenu = () => { flexDirection={{ base: 'column', xl: 'row' }} gap={{ base: 4, xl: 1 }} > - {isModelManagerEnabled && ( - - } - /> - - )} - { + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + + const { data: vaeModels } = useGetVaeModelsQuery(); + + const selectedModelId = useAppSelector( + (state: RootState) => state.generation.vae + ); + + const data = useMemo(() => { + if (!vaeModels) { + return []; + } + + const data: SelectItem[] = [ + { + value: 'auto', + label: 'Automatic', + group: 'Default', + }, + ]; + + forEach(vaeModels.entities, (model, id) => { + if (!model) { + return; + } + + data.push({ + value: id, + label: model.name, + group: MODEL_TYPE_MAP[model.base_model], + }); + }); + + return data; + }, [vaeModels]); + + const selectedModel = useMemo( + () => vaeModels?.entities[selectedModelId], + [vaeModels?.entities, selectedModelId] + ); + + const handleChangeModel = useCallback( + (v: string | null) => { + if (!v) { + return; + } + dispatch(vaeSelected(v)); + }, + [dispatch] + ); + + useEffect(() => { + if (selectedModelId && vaeModels?.ids.includes(selectedModelId)) { + return; + } + handleChangeModel('auto'); + }, [handleChangeModel, vaeModels?.ids, selectedModelId]); + + return ( + + ); +}; + +export default memo(VAESelect); diff --git a/invokeai/frontend/web/src/features/ui/components/FloatingGalleryButton.tsx b/invokeai/frontend/web/src/features/ui/components/FloatingGalleryButton.tsx index 3e2c2153e6..af3eb72d8d 100644 --- a/invokeai/frontend/web/src/features/ui/components/FloatingGalleryButton.tsx +++ b/invokeai/frontend/web/src/features/ui/components/FloatingGalleryButton.tsx @@ -1,13 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIIconButton from 'common/components/IAIIconButton'; -import { useTranslation } from 'react-i18next'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { setShouldShowGallery } from 'features/ui/store/uiSlice'; import { isEqual } from 'lodash-es'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; import { MdPhotoLibrary } from 'react-icons/md'; import { activeTabNameSelector, uiSelector } from '../store/uiSelectors'; -import { memo } from 'react'; +import { NO_GALLERY_TABS } from './InvokeTabs'; const floatingGalleryButtonSelector = createSelector( [activeTabNameSelector, uiSelector], @@ -16,7 +17,9 @@ const floatingGalleryButtonSelector = createSelector( return { shouldPinGallery, - shouldShowGalleryButton: !shouldShowGallery, + shouldShowGalleryButton: NO_GALLERY_TABS.includes(activeTabName) + ? false + : !shouldShowGallery, }; }, { memoizeOptions: { resultEqualityCheck: isEqual } } diff --git a/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx b/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx index 6bbeedcaaa..c618997f03 100644 --- a/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx +++ b/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx @@ -9,34 +9,35 @@ import { Tooltip, VisuallyHidden, } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import AuxiliaryProgressIndicator from 'app/components/AuxiliaryProgressIndicator'; import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; +import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent'; import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice'; +import { configSelector } from 'features/system/store/configSelectors'; import { InvokeTabName } from 'features/ui/store/tabMap'; import { setActiveTab, togglePanels } from 'features/ui/store/uiSlice'; -import { memo, MouseEvent, ReactNode, useCallback, useMemo } from 'react'; +import { ResourceKey } from 'i18next'; +import { isEqual } from 'lodash-es'; +import { MouseEvent, ReactNode, memo, useCallback, useMemo } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; +import { useTranslation } from 'react-i18next'; +import { FaCube, FaFont, FaImage } from 'react-icons/fa'; import { MdDeviceHub, MdGridOn } from 'react-icons/md'; +import { Panel, PanelGroup } from 'react-resizable-panels'; +import { useMinimumPanelSize } from '../hooks/useMinimumPanelSize'; import { activeTabIndexSelector, activeTabNameSelector, } from '../store/uiSelectors'; -import { useTranslation } from 'react-i18next'; -import { ResourceKey } from 'i18next'; -import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; -import { createSelector } from '@reduxjs/toolkit'; -import { configSelector } from 'features/system/store/configSelectors'; -import { isEqual } from 'lodash-es'; -import { Panel, PanelGroup } from 'react-resizable-panels'; -import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent'; +import ImageTab from './tabs/ImageToImage/ImageToImageTab'; +import ModelManagerTab from './tabs/ModelManager/ModelManagerTab'; +import NodesTab from './tabs/Nodes/NodesTab'; +import ResizeHandle from './tabs/ResizeHandle'; import TextToImageTab from './tabs/TextToImage/TextToImageTab'; import UnifiedCanvasTab from './tabs/UnifiedCanvas/UnifiedCanvasTab'; -import NodesTab from './tabs/Nodes/NodesTab'; -import { FaFont, FaImage } from 'react-icons/fa'; -import ResizeHandle from './tabs/ResizeHandle'; -import ImageTab from './tabs/ImageToImage/ImageToImageTab'; -import AuxiliaryProgressIndicator from 'app/components/AuxiliaryProgressIndicator'; -import { useMinimumPanelSize } from '../hooks/useMinimumPanelSize'; export interface InvokeTabInfo { id: InvokeTabName; @@ -65,6 +66,16 @@ const tabs: InvokeTabInfo[] = [ icon: , content: , }, + { + id: 'modelManager', + icon: , + content: , + }, + // { + // id: 'batch', + // icon: , + // content: , + // }, ]; const enabledTabsSelector = createSelector( @@ -81,6 +92,7 @@ const enabledTabsSelector = createSelector( const MIN_GALLERY_WIDTH = 300; const DEFAULT_GALLERY_PCT = 20; +export const NO_GALLERY_TABS: InvokeTabName[] = ['modelManager']; const InvokeTabs = () => { const activeTab = useAppSelector(activeTabIndexSelector); @@ -192,26 +204,28 @@ const InvokeTabs = () => { {tabPanels} - {shouldPinGallery && shouldShowGallery && ( - <> - - DEFAULT_GALLERY_PCT - ? galleryMinSizePct - : DEFAULT_GALLERY_PCT - } - minSize={galleryMinSizePct} - maxSize={50} - > - - - - )} + {shouldPinGallery && + shouldShowGallery && + !NO_GALLERY_TABS.includes(activeTabName) && ( + <> + + DEFAULT_GALLERY_PCT + ? galleryMinSizePct + : DEFAULT_GALLERY_PCT + } + minSize={galleryMinSizePct} + maxSize={50} + > + + + + )} ); diff --git a/invokeai/frontend/web/src/features/ui/components/ParametersDrawer.tsx b/invokeai/frontend/web/src/features/ui/components/ParametersDrawer.tsx index b41017c2c9..0777463ec4 100644 --- a/invokeai/frontend/web/src/features/ui/components/ParametersDrawer.tsx +++ b/invokeai/frontend/web/src/features/ui/components/ParametersDrawer.tsx @@ -71,7 +71,15 @@ const ParametersDrawer = () => { onClose={handleClosePanel} > { - - {drawerContent} - + + {drawerContent} + ); diff --git a/invokeai/frontend/web/src/features/ui/components/ParametersPinnedWrapper.tsx b/invokeai/frontend/web/src/features/ui/components/ParametersPinnedWrapper.tsx index d47ca3e1ba..f327e10efc 100644 --- a/invokeai/frontend/web/src/features/ui/components/ParametersPinnedWrapper.tsx +++ b/invokeai/frontend/web/src/features/ui/components/ParametersPinnedWrapper.tsx @@ -42,18 +42,10 @@ const ParametersPinnedWrapper = (props: ParametersPinnedWrapperProps) => { h: 'full', w: 'full', position: 'absolute', + overflowY: 'auto', }} > - - - {props.children} - - + {props.children} { }; return ( - - : } - variant="ghost" - size="sm" - sx={{ - color: 'base.700', - _hover: { - color: 'base.550', - }, - _active: { - color: 'base.500', - }, - ...sx, - }} - /> - + : } + variant="ghost" + size="sm" + sx={{ + color: 'base.700', + _hover: { + color: 'base.550', + }, + _active: { + color: 'base.500', + }, + ...sx, + }} + /> ); }; diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/Batch/BatchTab.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/Batch/BatchTab.tsx new file mode 100644 index 0000000000..811660c174 --- /dev/null +++ b/invokeai/frontend/web/src/features/ui/components/tabs/Batch/BatchTab.tsx @@ -0,0 +1,43 @@ +import { Box, Flex } from '@chakra-ui/react'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; +import InitialImageDisplay from 'features/parameters/components/Parameters/ImageToImage/InitialImageDisplay'; +import { memo, useCallback, useRef } from 'react'; +import { + ImperativePanelGroupHandle, + Panel, + PanelGroup, +} from 'react-resizable-panels'; +import ResizeHandle from '../ResizeHandle'; +import TextToImageTabMain from '../TextToImage/TextToImageTabMain'; +import BatchManager from 'features/batch/components/BatchManager'; + +const ImageToImageTab = () => { + const dispatch = useAppDispatch(); + const panelGroupRef = useRef(null); + + const handleDoubleClickHandle = useCallback(() => { + if (!panelGroupRef.current) { + return; + } + + panelGroupRef.current.setLayout([50, 50]); + }, []); + + return ( + + + + ); +}; + +export default memo(ImageToImageTab); diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters.tsx index cdbec9b55d..5f5c7ad46b 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters.tsx @@ -1,38 +1,45 @@ -import { memo } from 'react'; -import { Box, Flex, useDisclosure } from '@chakra-ui/react'; +import { Box, Flex } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; -import { uiSelector } from 'features/ui/store/uiSelectors'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations'; -import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps'; -import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale'; -import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth'; -import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight'; -import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength'; -import ImageToImageFit from 'features/parameters/components/Parameters/ImageToImage/ImageToImageFit'; -import { generationSelector } from 'features/parameters/store/generationSelectors'; -import ParamSchedulerAndModel from 'features/parameters/components/Parameters/Core/ParamSchedulerAndModel'; -import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull'; import IAICollapse from 'common/components/IAICollapse'; +import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale'; +import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight'; +import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations'; +import ParamModelandVAE from 'features/parameters/components/Parameters/Core/ParamModelandVAE'; +import ParamScheduler from 'features/parameters/components/Parameters/Core/ParamScheduler'; +import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps'; +import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth'; +import ImageToImageFit from 'features/parameters/components/Parameters/ImageToImage/ImageToImageFit'; +import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength'; +import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull'; +import { generationSelector } from 'features/parameters/store/generationSelectors'; +import { uiSelector } from 'features/ui/store/uiSelectors'; +import { memo } from 'react'; const selector = createSelector( [uiSelector, generationSelector], (ui, generation) => { const { shouldUseSliders } = ui; - const { shouldFitToWidthHeight } = generation; + const { shouldFitToWidthHeight, shouldRandomizeSeed } = generation; - return { shouldUseSliders, shouldFitToWidthHeight }; + const activeLabel = !shouldRandomizeSeed ? 'Manual Seed' : undefined; + + return { shouldUseSliders, shouldFitToWidthHeight, activeLabel }; }, defaultSelectorOptions ); const ImageToImageTabCoreParameters = () => { - const { shouldUseSliders, shouldFitToWidthHeight } = useAppSelector(selector); - const { isOpen, onToggle } = useDisclosure({ defaultIsOpen: true }); + const { shouldUseSliders, shouldFitToWidthHeight, activeLabel } = + useAppSelector(selector); return ( - + { > {shouldUseSliders ? ( <> - + @@ -58,7 +65,8 @@ const ImageToImageTabCoreParameters = () => { - + + diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabParameters.tsx index 4f04abffa1..32b71d6187 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabParameters.tsx @@ -1,14 +1,15 @@ -import { memo } from 'react'; -import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; -import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning'; -import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning'; -import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; -import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse'; -import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse'; -import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse'; -import ImageToImageTabCoreParameters from './ImageToImageTabCoreParameters'; -import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse'; +import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse'; +import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; +import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning'; +import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning'; +import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse'; +import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse'; +import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse'; +import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; +import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; +import { memo } from 'react'; +import ImageToImageTabCoreParameters from './ImageToImageTabCoreParameters'; const ImageToImageTabParameters = () => { return ( @@ -17,6 +18,7 @@ const ImageToImageTabParameters = () => { + diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/ModelManagerTab.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/ModelManagerTab.tsx new file mode 100644 index 0000000000..8d675b17c8 --- /dev/null +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/ModelManagerTab.tsx @@ -0,0 +1,81 @@ +import { Tab, TabList, TabPanel, TabPanels, Tabs } from '@chakra-ui/react'; +import i18n from 'i18n'; +import { ReactNode, memo } from 'react'; +import AddModelsPanel from './subpanels/AddModelsPanel'; +import MergeModelsPanel from './subpanels/MergeModelsPanel'; +import ModelManagerPanel from './subpanels/ModelManagerPanel'; + +type ModelManagerTabName = 'modelManager' | 'addModels' | 'mergeModels'; + +type ModelManagerTabInfo = { + id: ModelManagerTabName; + label: string; + content: ReactNode; +}; + +const modelManagerTabs: ModelManagerTabInfo[] = [ + { + id: 'modelManager', + label: i18n.t('modelManager.modelManager'), + content: , + }, + { + id: 'addModels', + label: i18n.t('modelManager.addModel'), + content: , + }, + { + id: 'mergeModels', + label: i18n.t('modelManager.mergeModels'), + content: , + }, +]; + +const renderTabsList = () => { + const modelManagerTabListsToRender: ReactNode[] = []; + modelManagerTabs.forEach((modelManagerTab) => { + modelManagerTabListsToRender.push( + {modelManagerTab.label} + ); + }); + + return ( + + {modelManagerTabListsToRender} + + ); +}; + +const renderTabPanels = () => { + const modelManagerTabPanelsToRender: ReactNode[] = []; + modelManagerTabs.forEach((modelManagerTab) => { + modelManagerTabPanelsToRender.push( + {modelManagerTab.content} + ); + }); + + return {modelManagerTabPanelsToRender}; +}; + +const ModelManagerTab = () => { + return ( + + {renderTabsList()} + {renderTabPanels()} + + ); +}; + +export default memo(ModelManagerTab); diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel.tsx new file mode 100644 index 0000000000..25f4adf4aa --- /dev/null +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel.tsx @@ -0,0 +1,55 @@ +import { Divider, Flex } from '@chakra-ui/react'; +import { RootState } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import IAIButton from 'common/components/IAIButton'; +import { setAddNewModelUIOption } from 'features/ui/store/uiSlice'; +import { useTranslation } from 'react-i18next'; +import AddCheckpointModel from './AddModelsPanel/AddCheckpointModel'; +import AddDiffusersModel from './AddModelsPanel/AddDiffusersModel'; + +export default function AddModelsPanel() { + const addNewModelUIOption = useAppSelector( + (state: RootState) => state.ui.addNewModelUIOption + ); + + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + + return ( + + + dispatch(setAddNewModelUIOption('ckpt'))} + sx={{ + backgroundColor: + addNewModelUIOption == 'ckpt' ? 'accent.700' : 'base.700', + '&:hover': { + backgroundColor: + addNewModelUIOption == 'ckpt' ? 'accent.700' : 'base.600', + }, + }} + > + {t('modelManager.addCheckpointModel')} + + dispatch(setAddNewModelUIOption('diffusers'))} + sx={{ + backgroundColor: + addNewModelUIOption == 'diffusers' ? 'accent.700' : 'base.700', + '&:hover': { + backgroundColor: + addNewModelUIOption == 'diffusers' ? 'accent.700' : 'base.600', + }, + }} + > + {t('modelManager.addDiffuserModel')} + + + + + + {addNewModelUIOption == 'ckpt' && } + {addNewModelUIOption == 'diffusers' && } + + ); +} diff --git a/invokeai/frontend/web/src/features/system/components/ModelManager/AddCheckpointModel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AddCheckpointModel.tsx similarity index 99% rename from invokeai/frontend/web/src/features/system/components/ModelManager/AddCheckpointModel.tsx rename to invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AddCheckpointModel.tsx index e6bd0b6ffb..75e2017bb8 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelManager/AddCheckpointModel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AddCheckpointModel.tsx @@ -10,13 +10,11 @@ import { } from '@chakra-ui/react'; import IAIButton from 'common/components/IAIButton'; -import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox'; import IAIInput from 'common/components/IAIInput'; import IAINumberInput from 'common/components/IAINumberInput'; +import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox'; import React from 'react'; -import SearchModels from './SearchModels'; - // import { addNewModel } from 'app/socketio/actions'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; @@ -24,12 +22,13 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { Field, Formik } from 'formik'; import { useTranslation } from 'react-i18next'; -import type { InvokeModelConfigProps } from 'app/types/invokeai'; import type { RootState } from 'app/store/store'; -import { setAddNewModelUIOption } from 'features/ui/store/uiSlice'; -import type { FieldInputProps, FormikProps } from 'formik'; +import type { InvokeModelConfigProps } from 'app/types/invokeai'; import IAIForm from 'common/components/IAIForm'; import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper'; +import { setAddNewModelUIOption } from 'features/ui/store/uiSlice'; +import type { FieldInputProps, FormikProps } from 'formik'; +import SearchModels from './SearchModels'; const MIN_MODEL_SIZE = 64; const MAX_MODEL_SIZE = 2048; diff --git a/invokeai/frontend/web/src/features/system/components/ModelManager/AddDiffusersModel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AddDiffusersModel.tsx similarity index 99% rename from invokeai/frontend/web/src/features/system/components/ModelManager/AddDiffusersModel.tsx rename to invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AddDiffusersModel.tsx index cb3af5f176..dd491828da 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelManager/AddDiffusersModel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AddDiffusersModel.tsx @@ -66,7 +66,7 @@ export default function AddDiffusersModel() { }; return ( - + value?.model_format === 'diffusers' + ); + + const [modelOne, setModelOne] = useState( + Object.keys(diffusersModels)[0] + ); + const [modelTwo, setModelTwo] = useState( + Object.keys(diffusersModels)[1] + ); + const [modelThree, setModelThree] = useState('none'); + + const [mergedModelName, setMergedModelName] = useState(''); + const [modelMergeAlpha, setModelMergeAlpha] = useState(0.5); + + const [modelMergeInterp, setModelMergeInterp] = useState< + 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference' + >('weighted_sum'); + + const [modelMergeSaveLocType, setModelMergeSaveLocType] = useState< + 'root' | 'custom' + >('root'); + + const [modelMergeCustomSaveLoc, setModelMergeCustomSaveLoc] = + useState(''); + + const [modelMergeForce, setModelMergeForce] = useState(false); + + const modelOneList = Object.keys(diffusersModels).filter( + (model) => model !== modelTwo && model !== modelThree + ); + + const modelTwoList = Object.keys(diffusersModels).filter( + (model) => model !== modelOne && model !== modelThree + ); + + const modelThreeList = [ + { key: t('modelManager.none'), value: 'none' }, + ...Object.keys(diffusersModels) + .filter((model) => model !== modelOne && model !== modelTwo) + .map((model) => ({ key: model, value: model })), + ]; + + const isProcessing = useAppSelector( + (state: RootState) => state.system.isProcessing + ); + + const mergeModelsHandler = () => { + let modelsToMerge: string[] = [modelOne, modelTwo, modelThree]; + modelsToMerge = modelsToMerge.filter((model) => model !== 'none'); + + const mergeModelsInfo: InvokeAI.InvokeModelMergingProps = { + models_to_merge: modelsToMerge, + merged_model_name: + mergedModelName !== '' ? mergedModelName : modelsToMerge.join('-'), + alpha: modelMergeAlpha, + interp: modelMergeInterp, + model_merge_save_path: + modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc, + force: modelMergeForce, + }; + + dispatch(mergeDiffusersModels(mergeModelsInfo)); + }; + + return ( + + + {t('modelManager.modelMergeHeaderHelp1')} + + {t('modelManager.modelMergeHeaderHelp2')} + + + + setModelOne(e.target.value)} + /> + setModelTwo(e.target.value)} + /> + { + if (e.target.value !== 'none') { + setModelThree(e.target.value); + setModelMergeInterp('add_difference'); + } else { + setModelThree('none'); + setModelMergeInterp('weighted_sum'); + } + }} + /> + + + setMergedModelName(e.target.value)} + /> + + + setModelMergeAlpha(v)} + withInput + withReset + handleReset={() => setModelMergeAlpha(0.5)} + withSliderMarks + /> + + {t('modelManager.modelMergeAlphaHelp')} + + + + + + {t('modelManager.interpolationType')} + + setModelMergeInterp(v)} + > + + {modelThree === 'none' ? ( + <> + + {t('modelManager.weightedSum')} + + + {t('modelManager.sigmoid')} + + + {t('modelManager.inverseSigmoid')} + + + ) : ( + + + {t('modelManager.addDifference')} + + + )} + + + + + + + + {t('modelManager.mergedModelSaveLocation')} + + setModelMergeSaveLocType(v)} + > + + + {t('modelManager.invokeAIFolder')} + + + + {t('modelManager.custom')} + + + + + + {modelMergeSaveLocType === 'custom' && ( + setModelMergeCustomSaveLoc(e.target.value)} + /> + )} + + + setModelMergeForce(e.target.checked)} + fontWeight="500" + /> + + + {t('modelManager.merge')} + + + ); +} diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx new file mode 100644 index 0000000000..b22a303571 --- /dev/null +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx @@ -0,0 +1,44 @@ +import { Flex } from '@chakra-ui/react'; +import { RootState } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; + +import { useGetMainModelsQuery } from 'services/api/endpoints/models'; +import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit'; +import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit'; +import ModelList from './ModelManagerPanel/ModelList'; + +export default function ModelManagerPanel() { + const { data: mainModels } = useGetMainModelsQuery(); + + const openModel = useAppSelector( + (state: RootState) => state.system.openModel + ); + + const renderModelEditTabs = () => { + if (!openModel || !mainModels) return; + + if (mainModels['entities'][openModel]['model_format'] === 'diffusers') { + return ( + + ); + } else { + return ( + + ); + } + }; + return ( + + + {renderModelEditTabs()} + + ); +} diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx new file mode 100644 index 0000000000..0d5d21175a --- /dev/null +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx @@ -0,0 +1,141 @@ +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; + +import { Divider, Flex, Text } from '@chakra-ui/react'; + +// import { addNewModel } from 'app/socketio/actions'; +import { useForm } from '@mantine/form'; +import { useTranslation } from 'react-i18next'; + +import type { RootState } from 'app/store/store'; +import IAIButton from 'common/components/IAIButton'; +import IAIInput from 'common/components/IAIInput'; +import IAIMantineSelect from 'common/components/IAIMantineSelect'; +import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect'; +import { S } from 'services/api/types'; +import ModelConvert from './ModelConvert'; + +const baseModelSelectData = [ + { value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] }, + { value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] }, +]; + +const variantSelectData = [ + { value: 'normal', label: 'Normal' }, + { value: 'inpaint', label: 'Inpaint' }, + { value: 'depth', label: 'Depth' }, +]; + +export type CheckpointModel = + | S<'StableDiffusion1ModelCheckpointConfig'> + | S<'StableDiffusion2ModelCheckpointConfig'>; + +type CheckpointModelEditProps = { + modelToEdit: string; + retrievedModel: CheckpointModel; +}; + +export default function CheckpointModelEdit(props: CheckpointModelEditProps) { + const isProcessing = useAppSelector( + (state: RootState) => state.system.isProcessing + ); + + const { modelToEdit, retrievedModel } = props; + + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + + const checkpointEditForm = useForm({ + initialValues: { + name: retrievedModel.name, + base_model: retrievedModel.base_model, + type: 'main', + path: retrievedModel.path, + description: retrievedModel.description, + model_format: 'checkpoint', + vae: retrievedModel.vae, + config: retrievedModel.config, + variant: retrievedModel.variant, + }, + }); + + const editModelFormSubmitHandler = (values) => { + console.log(values); + }; + + return modelToEdit ? ( + + + + + {retrievedModel.name} + + + {MODEL_TYPE_MAP[retrievedModel.base_model]} Model + + + + + + + +
+ editModelFormSubmitHandler(values) + )} + > + + + + + + + + + + {t('modelManager.updateModel')} + + +
+
+
+ ) : ( + + Pick A Model To Edit + + ); +} diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx new file mode 100644 index 0000000000..6a7b4b3140 --- /dev/null +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx @@ -0,0 +1,125 @@ +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; + +import { Divider, Flex, Text } from '@chakra-ui/react'; + +// import { addNewModel } from 'app/socketio/actions'; +import { useTranslation } from 'react-i18next'; + +import { useForm } from '@mantine/form'; +import type { RootState } from 'app/store/store'; +import IAIButton from 'common/components/IAIButton'; +import IAIInput from 'common/components/IAIInput'; +import IAIMantineSelect from 'common/components/IAIMantineSelect'; +import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect'; +import { S } from 'services/api/types'; + +type DiffusersModel = + | S<'StableDiffusion1ModelDiffusersConfig'> + | S<'StableDiffusion2ModelDiffusersConfig'>; + +type DiffusersModelEditProps = { + modelToEdit: string; + retrievedModel: DiffusersModel; +}; + +const baseModelSelectData = [ + { value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] }, + { value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] }, +]; + +const variantSelectData = [ + { value: 'normal', label: 'Normal' }, + { value: 'inpaint', label: 'Inpaint' }, + { value: 'depth', label: 'Depth' }, +]; + +export default function DiffusersModelEdit(props: DiffusersModelEditProps) { + const isProcessing = useAppSelector( + (state: RootState) => state.system.isProcessing + ); + const { retrievedModel, modelToEdit } = props; + + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + + const diffusersEditForm = useForm({ + initialValues: { + name: retrievedModel.name, + base_model: retrievedModel.base_model, + type: 'main', + path: retrievedModel.path, + description: retrievedModel.description, + model_format: 'diffusers', + vae: retrievedModel.vae, + variant: retrievedModel.variant, + }, + }); + + const editModelFormSubmitHandler = (values) => { + console.log(values); + }; + + return modelToEdit ? ( + + + + {retrievedModel.name} + + + {MODEL_TYPE_MAP[retrievedModel.base_model]} Model + + + + +
+ editModelFormSubmitHandler(values) + )} + > + + + + + + + + + {t('modelManager.updateModel')} + + +
+
+ ) : ( + + Pick A Model To Edit + + ); +} diff --git a/invokeai/frontend/web/src/features/system/components/ModelManager/ModelConvert.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx similarity index 84% rename from invokeai/frontend/web/src/features/system/components/ModelManager/ModelConvert.tsx rename to invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx index 820ad546b3..9f571c2fff 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelManager/ModelConvert.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx @@ -4,42 +4,28 @@ import { Radio, RadioGroup, Text, - UnorderedList, Tooltip, + UnorderedList, } from '@chakra-ui/react'; // import { convertToDiffusers } from 'app/socketio/actions'; -import { RootState } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useAppDispatch } from 'app/store/storeHooks'; import IAIAlertDialog from 'common/components/IAIAlertDialog'; import IAIButton from 'common/components/IAIButton'; import IAIInput from 'common/components/IAIInput'; -import { useState, useEffect } from 'react'; +import { useEffect, useState } from 'react'; import { useTranslation } from 'react-i18next'; +import { CheckpointModel } from './CheckpointModelEdit'; interface ModelConvertProps { - model: string; + model: CheckpointModel; } export default function ModelConvert(props: ModelConvertProps) { const { model } = props; - const model_list = useAppSelector( - (state: RootState) => state.system.model_list - ); - - const retrievedModel = model_list[model]; - const dispatch = useAppDispatch(); const { t } = useTranslation(); - const isProcessing = useAppSelector( - (state: RootState) => state.system.isProcessing - ); - - const isConnected = useAppSelector( - (state: RootState) => state.system.isConnected - ); - const [saveLocation, setSaveLocation] = useState('same'); const [customSaveLocation, setCustomSaveLocation] = useState(''); @@ -65,7 +51,7 @@ export default function ModelConvert(props: ModelConvertProps) { return ( 🧨 {t('modelManager.convertToDiffusers')} diff --git a/invokeai/frontend/web/src/features/system/components/ModelManager/ModelList.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx similarity index 75% rename from invokeai/frontend/web/src/features/system/components/ModelManager/ModelList.tsx rename to invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx index 4ef311e1d4..eb05e70357 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelManager/ModelList.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx @@ -1,36 +1,14 @@ -import { Box, Flex, Heading, Spacer, Spinner, Text } from '@chakra-ui/react'; -import IAIInput from 'common/components/IAIInput'; +import { Box, Flex, Spinner, Text } from '@chakra-ui/react'; import IAIButton from 'common/components/IAIButton'; +import IAIInput from 'common/components/IAIInput'; -import AddModel from './AddModel'; import ModelListItem from './ModelListItem'; -import MergeModels from './MergeModels'; -import { useAppSelector } from 'app/store/storeHooks'; import { useTranslation } from 'react-i18next'; -import { createSelector } from '@reduxjs/toolkit'; -import { systemSelector } from 'features/system/store/systemSelectors'; -import type { SystemState } from 'features/system/store/systemSlice'; -import { isEqual, map } from 'lodash-es'; - -import React, { useMemo, useState, useTransition } from 'react'; import type { ChangeEvent, ReactNode } from 'react'; - -const modelListSelector = createSelector( - systemSelector, - (system: SystemState) => { - const models = map(system.model_list, (model, key) => { - return { name: key, ...model }; - }); - return models; - }, - { - memoizeOptions: { - resultEqualityCheck: isEqual, - }, - } -); +import React, { useMemo, useState, useTransition } from 'react'; +import { useGetMainModelsQuery } from 'services/api/endpoints/models'; function ModelFilterButton({ label, @@ -58,7 +36,7 @@ function ModelFilterButton({ } const ModelList = () => { - const models = useAppSelector(modelListSelector); + const { data: mainModels } = useGetMainModelsQuery(); const [renderModelList, setRenderModelList] = React.useState(false); @@ -90,43 +68,49 @@ const ModelList = () => { const filteredModelListItemsToRender: ReactNode[] = []; const localFilteredModelListItemsToRender: ReactNode[] = []; - models.forEach((model, i) => { - if (model.name.toLowerCase().includes(searchText.toLowerCase())) { + if (!mainModels) return; + + const modelList = mainModels.entities; + + Object.keys(modelList).forEach((model, i) => { + if ( + modelList[model].name.toLowerCase().includes(searchText.toLowerCase()) + ) { filteredModelListItemsToRender.push( ); - if (model.format === isSelectedFilter) { + if (modelList[model]?.model_format === isSelectedFilter) { localFilteredModelListItemsToRender.push( ); } } - if (model.format !== 'diffusers') { + if (modelList[model]?.model_format !== 'diffusers') { ckptModelListItemsToRender.push( ); } else { diffusersModelListItemsToRender.push( ); } @@ -142,6 +126,23 @@ const ModelList = () => { {isSelectedFilter === 'all' && ( <> + + + {t('modelManager.diffusersModels')} + + {diffusersModelListItemsToRender} + { {ckptModelListItemsToRender} - - - {t('modelManager.diffusersModels')} - - {diffusersModelListItemsToRender} - )} - {isSelectedFilter === 'ckpt' && ( - - {ckptModelListItemsToRender} - - )} - {isSelectedFilter === 'diffusers' && ( {diffusersModelListItemsToRender} )} + + {isSelectedFilter === 'ckpt' && ( + + {ckptModelListItemsToRender} + + )} ); - }, [models, searchText, t, isSelectedFilter]); + }, [mainModels, searchText, t, isSelectedFilter]); return ( - - {t('modelManager.availableModels')} - - - - - { { onClick={() => setIsSelectedFilter('all')} isActive={isSelectedFilter === 'all'} /> - setIsSelectedFilter('ckpt')} - isActive={isSelectedFilter === 'ckpt'} - /> setIsSelectedFilter('diffusers')} isActive={isSelectedFilter === 'diffusers'} /> + setIsSelectedFilter('ckpt')} + isActive={isSelectedFilter === 'ckpt'} + /> {renderModelList ? ( diff --git a/invokeai/frontend/web/src/features/system/components/ModelManager/ModelListItem.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx similarity index 75% rename from invokeai/frontend/web/src/features/system/components/ModelManager/ModelListItem.tsx rename to invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx index aa9f87816c..ab5fddd5ea 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelManager/ModelListItem.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx @@ -1,6 +1,6 @@ import { DeleteIcon, EditIcon } from '@chakra-ui/icons'; -import { Box, Button, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react'; -import { ModelStatus } from 'app/types/invokeai'; +import { Box, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react'; + // import { deleteModel, requestModelChange } from 'app/socketio/actions'; import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; @@ -10,9 +10,9 @@ import { setOpenModel } from 'features/system/store/systemSlice'; import { useTranslation } from 'react-i18next'; type ModelListItemProps = { + modelKey: string; name: string; - status: ModelStatus; - description: string; + description: string | undefined; }; export default function ModelListItem(props: ModelListItemProps) { @@ -28,39 +28,24 @@ export default function ModelListItem(props: ModelListItemProps) { const dispatch = useAppDispatch(); - const { name, status, description } = props; - - const handleChangeModel = () => { - dispatch(requestModelChange(name)); - }; + const { modelKey, name, description } = props; const openModelHandler = () => { - dispatch(setOpenModel(name)); + dispatch(setOpenModel(modelKey)); }; const handleModelDelete = () => { - dispatch(deleteModel(name)); + dispatch(deleteModel(modelKey)); dispatch(setOpenModel(null)); }; - const statusTextColor = () => { - switch (status) { - case 'active': - return 'ok.500'; - case 'cached': - return 'warning.500'; - case 'not loaded': - return 'inherit'; - } - }; - return ( - {status} - - } size="sm" diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters.tsx index 07297bda31..9211e095ba 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters.tsx @@ -1,34 +1,41 @@ -import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations'; -import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps'; -import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale'; -import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth'; -import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight'; -import { Box, Flex, useDisclosure } from '@chakra-ui/react'; -import { useAppSelector } from 'app/store/storeHooks'; +import { Box, Flex } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; -import { uiSelector } from 'features/ui/store/uiSelectors'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { memo } from 'react'; -import ParamSchedulerAndModel from 'features/parameters/components/Parameters/Core/ParamSchedulerAndModel'; import IAICollapse from 'common/components/IAICollapse'; +import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale'; +import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight'; +import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations'; +import ParamModelandVAE from 'features/parameters/components/Parameters/Core/ParamModelandVAE'; +import ParamScheduler from 'features/parameters/components/Parameters/Core/ParamScheduler'; +import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps'; +import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth'; import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull'; +import { memo } from 'react'; const selector = createSelector( - uiSelector, - (ui) => { + stateSelector, + ({ ui, generation }) => { const { shouldUseSliders } = ui; + const { shouldRandomizeSeed } = generation; - return { shouldUseSliders }; + const activeLabel = !shouldRandomizeSeed ? 'Manual Seed' : undefined; + + return { shouldUseSliders, activeLabel }; }, defaultSelectorOptions ); const TextToImageTabCoreParameters = () => { - const { shouldUseSliders } = useAppSelector(selector); - const { isOpen, onToggle } = useDisclosure({ defaultIsOpen: true }); + const { shouldUseSliders, activeLabel } = useAppSelector(selector); return ( - + { > {shouldUseSliders ? ( <> - + @@ -54,7 +61,8 @@ const TextToImageTabCoreParameters = () => { - + + diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabParameters.tsx index bcc6c91ae6..6291b69a8e 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabParameters.tsx @@ -1,15 +1,16 @@ +import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse'; +import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse'; +import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; +import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning'; +import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning'; +import ParamHiresCollapse from 'features/parameters/components/Parameters/Hires/ParamHiresCollapse'; +import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse'; +import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse'; +import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse'; +import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; import { memo } from 'react'; -import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning'; -import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning'; -import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; -import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse'; -import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse'; -import ParamHiresCollapse from 'features/parameters/components/Parameters/Hires/ParamHiresCollapse'; -import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse'; import TextToImageTabCoreParameters from './TextToImageTabCoreParameters'; -import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; -import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse'; const TextToImageTabParameters = () => { return ( @@ -18,6 +19,7 @@ const TextToImageTabParameters = () => { + diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasContent.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasContent.tsx index 77085bcb75..5474fe8358 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasContent.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasContent.tsx @@ -14,8 +14,12 @@ import UnifiedCanvasToolbarBeta from './UnifiedCanvasBeta/UnifiedCanvasToolbarBe import UnifiedCanvasToolSettingsBeta from './UnifiedCanvasBeta/UnifiedCanvasToolSettingsBeta'; import { ImageDTO } from 'services/api/types'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; -import { useDroppable } from '@dnd-kit/core'; import IAIDropOverlay from 'common/components/IAIDropOverlay'; +import { + CanvasInitialImageDropData, + isValidDrop, + useDroppable, +} from 'app/components/ImageDnd/typesafeDnd'; const selector = createSelector( [canvasSelector, uiSelector], @@ -30,28 +34,24 @@ const selector = createSelector( defaultSelectorOptions ); +const droppableData: CanvasInitialImageDropData = { + id: 'canvas-intial-image', + actionType: 'SET_CANVAS_INITIAL_IMAGE', +}; + const UnifiedCanvasContent = () => { const dispatch = useAppDispatch(); const { doesCanvasNeedScaling, shouldUseCanvasBetaLayout } = useAppSelector(selector); - const onDrop = useCallback( - (droppedImage: ImageDTO) => { - dispatch(setInitialCanvasImage(droppedImage)); - }, - [dispatch] - ); - const { isOver, setNodeRef: setDroppableRef, active, } = useDroppable({ id: 'unifiedCanvas', - data: { - handleDrop: onDrop, - }, + data: droppableData, }); useLayoutEffect(() => { @@ -97,7 +97,12 @@ const UnifiedCanvasContent = () => { {doesCanvasNeedScaling ? : } - {active && } + {isValidDrop(droppableData, active) && ( + + )} @@ -139,7 +144,12 @@ const UnifiedCanvasContent = () => { > {doesCanvasNeedScaling ? : } - {active && } + {isValidDrop(droppableData, active) && ( + + )}
diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx index 42e19eb096..330cd8b31e 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx @@ -1,35 +1,42 @@ -import { memo } from 'react'; -import { Box, Flex, useDisclosure } from '@chakra-ui/react'; +import { Box, Flex } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; -import { uiSelector } from 'features/ui/store/uiSelectors'; +import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations'; -import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps'; -import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale'; -import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength'; -import ParamSchedulerAndModel from 'features/parameters/components/Parameters/Core/ParamSchedulerAndModel'; -import ParamBoundingBoxWidth from 'features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxWidth'; -import ParamBoundingBoxHeight from 'features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxHeight'; -import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull'; import IAICollapse from 'common/components/IAICollapse'; +import ParamBoundingBoxHeight from 'features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxHeight'; +import ParamBoundingBoxWidth from 'features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxWidth'; +import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale'; +import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations'; +import ParamModelandVAE from 'features/parameters/components/Parameters/Core/ParamModelandVAE'; +import ParamScheduler from 'features/parameters/components/Parameters/Core/ParamScheduler'; +import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps'; +import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength'; +import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull'; +import { memo } from 'react'; const selector = createSelector( - uiSelector, - (ui) => { + stateSelector, + ({ ui, generation }) => { const { shouldUseSliders } = ui; + const { shouldRandomizeSeed } = generation; - return { shouldUseSliders }; + const activeLabel = !shouldRandomizeSeed ? 'Manual Seed' : undefined; + + return { shouldUseSliders, activeLabel }; }, defaultSelectorOptions ); const UnifiedCanvasCoreParameters = () => { - const { shouldUseSliders } = useAppSelector(selector); - const { isOpen, onToggle } = useDisclosure({ defaultIsOpen: true }); + const { shouldUseSliders, activeLabel } = useAppSelector(selector); return ( - + { > {shouldUseSliders ? ( <> - + @@ -55,7 +62,8 @@ const UnifiedCanvasCoreParameters = () => { - + + diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx index 061ebb962e..63ed4cc1cf 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx @@ -1,14 +1,15 @@ -import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; -import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; -import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse'; +import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse'; +import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse'; import ParamInfillAndScalingCollapse from 'features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse'; import ParamSeamCorrectionCollapse from 'features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse'; -import UnifiedCanvasCoreParameters from './UnifiedCanvasCoreParameters'; -import { memo } from 'react'; -import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning'; -import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning'; import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; -import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse'; +import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning'; +import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning'; +import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse'; +import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; +import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; +import { memo } from 'react'; +import UnifiedCanvasCoreParameters from './UnifiedCanvasCoreParameters'; const UnifiedCanvasParameters = () => { return ( @@ -17,6 +18,7 @@ const UnifiedCanvasParameters = () => { + diff --git a/invokeai/frontend/web/src/features/ui/store/tabMap.ts b/invokeai/frontend/web/src/features/ui/store/tabMap.ts index becf52886e..0cae8eac43 100644 --- a/invokeai/frontend/web/src/features/ui/store/tabMap.ts +++ b/invokeai/frontend/web/src/features/ui/store/tabMap.ts @@ -1,11 +1,10 @@ export const tabMap = [ 'txt2img', 'img2img', - // 'generate', 'unifiedCanvas', 'nodes', - // 'postprocessing', - // 'training', + 'modelManager', + 'batch', ] as const; export type InvokeTabName = (typeof tabMap)[number]; diff --git a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts index 38af668cac..861bf49405 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts @@ -1,10 +1,10 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; import { initialImageChanged } from 'features/parameters/store/generationSlice'; +import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; import { setActiveTabReducer } from './extraReducers'; import { InvokeTabName } from './tabMap'; import { AddNewModelType, UIState } from './uiTypes'; -import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; export const initialUIState: UIState = { activeTab: 0, @@ -19,6 +19,7 @@ export const initialUIState: UIState = { shouldShowGallery: true, shouldHidePreview: false, shouldShowProgressInViewer: true, + shouldShowEmbeddingPicker: false, favoriteSchedulers: [], }; @@ -96,6 +97,9 @@ export const uiSlice = createSlice({ ) => { state.favoriteSchedulers = action.payload; }, + toggleEmbeddingPicker: (state) => { + state.shouldShowEmbeddingPicker = !state.shouldShowEmbeddingPicker; + }, }, extraReducers(builder) { builder.addCase(initialImageChanged, (state) => { @@ -122,6 +126,7 @@ export const { toggleGalleryPanel, setShouldShowProgressInViewer, favoriteSchedulersChanged, + toggleEmbeddingPicker, } = uiSlice.actions; export default uiSlice.reducer; diff --git a/invokeai/frontend/web/src/features/ui/store/uiTypes.ts b/invokeai/frontend/web/src/features/ui/store/uiTypes.ts index d55a1d8fcf..ad0250e56d 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiTypes.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiTypes.ts @@ -27,5 +27,6 @@ export interface UIState { shouldPinGallery: boolean; shouldShowGallery: boolean; shouldShowProgressInViewer: boolean; + shouldShowEmbeddingPicker: boolean; favoriteSchedulers: SchedulerParam[]; } diff --git a/invokeai/frontend/web/src/services/api/endpoints/boardImages.ts b/invokeai/frontend/web/src/services/api/endpoints/boardImages.ts index cef9ab7cae..a0db3f3dff 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/boardImages.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/boardImages.ts @@ -1,6 +1,7 @@ import { OffsetPaginatedResults_ImageDTO_ } from 'services/api/types'; import { api } from '..'; import { paths } from '../schema'; +import { imagesApi } from './images'; type ListBoardImagesArg = paths['/api/v1/board_images/{board_id}']['get']['parameters']['path'] & @@ -41,8 +42,22 @@ export const boardImagesApi = api.injectEndpoints({ }), invalidatesTags: (result, error, arg) => [ { type: 'Board', id: arg.board_id }, - { type: 'Image', id: arg.image_name }, ], + async onQueryStarted( + { image_name, ...patch }, + { dispatch, queryFulfilled } + ) { + const patchResult = dispatch( + imagesApi.util.updateQueryData('getImageDTO', image_name, (draft) => { + Object.assign(draft, patch); + }) + ); + try { + await queryFulfilled; + } catch { + patchResult.undo(); + } + }, }), removeImageFromBoard: build.mutation({ @@ -53,8 +68,22 @@ export const boardImagesApi = api.injectEndpoints({ }), invalidatesTags: (result, error, arg) => [ { type: 'Board', id: arg.board_id }, - { type: 'Image', id: arg.image_name }, ], + async onQueryStarted( + { image_name, ...patch }, + { dispatch, queryFulfilled } + ) { + const patchResult = dispatch( + imagesApi.util.updateQueryData('getImageDTO', image_name, (draft) => { + Object.assign(draft, { board_id: null }); + }) + ); + try { + await queryFulfilled; + } catch { + patchResult.undo(); + } + }, }), }), }); diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 39e4e46d3b..a9a914f0f2 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -1,37 +1,85 @@ -import { ModelsList } from 'services/api/types'; import { EntityState, createEntityAdapter } from '@reduxjs/toolkit'; -import { keyBy } from 'lodash-es'; +import { cloneDeep } from 'lodash-es'; +import { + AnyModelConfig, + ControlNetModelConfig, + LoRAModelConfig, + MainModelConfig, + TextualInversionModelConfig, + VaeModelConfig, +} from 'services/api/types'; import { ApiFullTagDescription, LIST_TAG, api } from '..'; -import { paths } from '../schema'; -type ModelConfig = ModelsList['models'][number]; +export type MainModelConfigEntity = MainModelConfig & { id: string }; -type ListModelsArg = NonNullable< - paths['/api/v1/models/']['get']['parameters']['query'] ->; +export type LoRAModelConfigEntity = LoRAModelConfig & { id: string }; -const modelsAdapter = createEntityAdapter({ - selectId: (model) => getModelId(model), +export type ControlNetModelConfigEntity = ControlNetModelConfig & { + id: string; +}; + +export type TextualInversionModelConfigEntity = TextualInversionModelConfig & { + id: string; +}; + +export type VaeModelConfigEntity = VaeModelConfig & { id: string }; + +type AnyModelConfigEntity = + | MainModelConfigEntity + | LoRAModelConfigEntity + | ControlNetModelConfigEntity + | TextualInversionModelConfigEntity + | VaeModelConfigEntity; + +const mainModelsAdapter = createEntityAdapter({ + sortComparer: (a, b) => a.name.localeCompare(b.name), +}); +const loraModelsAdapter = createEntityAdapter({ + sortComparer: (a, b) => a.name.localeCompare(b.name), +}); +const controlNetModelsAdapter = + createEntityAdapter({ + sortComparer: (a, b) => a.name.localeCompare(b.name), + }); +const textualInversionModelsAdapter = + createEntityAdapter({ + sortComparer: (a, b) => a.name.localeCompare(b.name), + }); +const vaeModelsAdapter = createEntityAdapter({ sortComparer: (a, b) => a.name.localeCompare(b.name), }); -const getModelId = ({ base_model, type, name }: ModelConfig) => +export const getModelId = ({ base_model, type, name }: AnyModelConfig) => `${base_model}/${type}/${name}`; +const createModelEntities = ( + models: AnyModelConfig[] +): T[] => { + const entityArray: T[] = []; + models.forEach((model) => { + const entity = { + ...cloneDeep(model), + id: getModelId(model), + } as T; + entityArray.push(entity); + }); + return entityArray; +}; + export const modelsApi = api.injectEndpoints({ endpoints: (build) => ({ - listModels: build.query, ListModelsArg>({ - query: (arg) => ({ url: 'models/', params: arg }), + getMainModels: build.query, void>({ + query: () => ({ url: 'models/', params: { model_type: 'main' } }), providesTags: (result, error, arg) => { - // any list of boards - const tags: ApiFullTagDescription[] = [{ id: 'Model', type: LIST_TAG }]; + const tags: ApiFullTagDescription[] = [ + { id: 'MainModel', type: LIST_TAG }, + ]; if (result) { - // and individual tags for each board tags.push( ...result.ids.map((id) => ({ - type: 'Model' as const, + type: 'MainModel' as const, id, })) ); @@ -39,14 +87,161 @@ export const modelsApi = api.injectEndpoints({ return tags; }, - transformResponse: (response: ModelsList, meta, arg) => { - return modelsAdapter.setAll( - modelsAdapter.getInitialState(), - keyBy(response.models, getModelId) + transformResponse: ( + response: { models: MainModelConfig[] }, + meta, + arg + ) => { + const entities = createModelEntities( + response.models + ); + return mainModelsAdapter.setAll( + mainModelsAdapter.getInitialState(), + entities + ); + }, + }), + getLoRAModels: build.query, void>({ + query: () => ({ url: 'models/', params: { model_type: 'lora' } }), + providesTags: (result, error, arg) => { + const tags: ApiFullTagDescription[] = [ + { id: 'LoRAModel', type: LIST_TAG }, + ]; + + if (result) { + tags.push( + ...result.ids.map((id) => ({ + type: 'LoRAModel' as const, + id, + })) + ); + } + + return tags; + }, + transformResponse: ( + response: { models: LoRAModelConfig[] }, + meta, + arg + ) => { + const entities = createModelEntities( + response.models + ); + return loraModelsAdapter.setAll( + loraModelsAdapter.getInitialState(), + entities + ); + }, + }), + getControlNetModels: build.query< + EntityState, + void + >({ + query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }), + providesTags: (result, error, arg) => { + const tags: ApiFullTagDescription[] = [ + { id: 'ControlNetModel', type: LIST_TAG }, + ]; + + if (result) { + tags.push( + ...result.ids.map((id) => ({ + type: 'ControlNetModel' as const, + id, + })) + ); + } + + return tags; + }, + transformResponse: ( + response: { models: ControlNetModelConfig[] }, + meta, + arg + ) => { + const entities = createModelEntities( + response.models + ); + return controlNetModelsAdapter.setAll( + controlNetModelsAdapter.getInitialState(), + entities + ); + }, + }), + getVaeModels: build.query, void>({ + query: () => ({ url: 'models/', params: { model_type: 'vae' } }), + providesTags: (result, error, arg) => { + const tags: ApiFullTagDescription[] = [ + { id: 'VaeModel', type: LIST_TAG }, + ]; + + if (result) { + tags.push( + ...result.ids.map((id) => ({ + type: 'VaeModel' as const, + id, + })) + ); + } + + return tags; + }, + transformResponse: ( + response: { models: VaeModelConfig[] }, + meta, + arg + ) => { + const entities = createModelEntities( + response.models + ); + return vaeModelsAdapter.setAll( + vaeModelsAdapter.getInitialState(), + entities + ); + }, + }), + getTextualInversionModels: build.query< + EntityState, + void + >({ + query: () => ({ url: 'models/', params: { model_type: 'embedding' } }), + providesTags: (result, error, arg) => { + const tags: ApiFullTagDescription[] = [ + { id: 'TextualInversionModel', type: LIST_TAG }, + ]; + + if (result) { + tags.push( + ...result.ids.map((id) => ({ + type: 'TextualInversionModel' as const, + id, + })) + ); + } + + return tags; + }, + transformResponse: ( + response: { models: TextualInversionModelConfig[] }, + meta, + arg + ) => { + const entities = createModelEntities( + response.models + ); + return textualInversionModelsAdapter.setAll( + textualInversionModelsAdapter.getInitialState(), + entities ); }, }), }), }); -export const { useListModelsQuery } = modelsApi; +export const { + useGetMainModelsQuery, + useGetControlNetModelsQuery, + useGetLoRAModelsQuery, + useGetTextualInversionModelsQuery, + useGetVaeModelsQuery, +} = modelsApi; diff --git a/invokeai/frontend/web/src/services/api/schema.d.ts b/invokeai/frontend/web/src/services/api/schema.d.ts index 767fe7b2b3..d7e50d004e 100644 --- a/invokeai/frontend/web/src/services/api/schema.d.ts +++ b/invokeai/frontend/web/src/services/api/schema.d.ts @@ -81,6 +81,13 @@ export type paths = { */ post: operations["update_model"]; }; + "/api/v1/models/import": { + /** + * Import Model + * @description Add a model using its local path, repo_id, or remote URL + */ + post: operations["import_model"]; + }; "/api/v1/models/{model_name}": { /** * Delete Model @@ -227,6 +234,23 @@ export type components = { */ b?: number; }; + /** AddModelResult */ + AddModelResult: { + /** + * Name + * @description The name of the model after import + */ + name: string; + /** @description The type of model */ + model_type: components["schemas"]["ModelType"]; + /** @description The base model */ + base_model: components["schemas"]["BaseModelType"]; + /** + * Config + * @description The configuration of the model + */ + config: components["schemas"]["ModelConfigBase"]; + }; /** * BaseModelType * @description An enumeration. @@ -650,7 +674,7 @@ export type components = { end_step_percent: number; /** * Control Mode - * @description The contorl mode to use + * @description The control mode to use * @default balanced * @enum {string} */ @@ -1030,7 +1054,7 @@ export type components = { * @description The nodes in this graph */ nodes?: { - [key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined; + [key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined; }; /** * Edges @@ -1073,7 +1097,7 @@ export type components = { * @description The results of node executions */ results: { - [key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined; + [key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined; }; /** * Errors @@ -1276,6 +1300,53 @@ export type components = { */ channel?: "A" | "R" | "G" | "B"; }; + /** + * ImageCollectionInvocation + * @description Load a collection of images and provide it as output. + */ + ImageCollectionInvocation: { + /** + * Id + * @description The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this node is an intermediate node. + * @default false + */ + is_intermediate?: boolean; + /** + * Type + * @default image_collection + * @enum {string} + */ + type?: "image_collection"; + /** + * Images + * @description The image collection to load + * @default [] + */ + images?: (components["schemas"]["ImageField"])[]; + }; + /** + * ImageCollectionOutput + * @description A collection of images + */ + ImageCollectionOutput: { + /** + * Type + * @default image_collection + * @enum {string} + */ + type: "image_collection"; + /** + * Collection + * @description The output images + * @default [] + */ + collection: (components["schemas"]["ImageField"])[]; + }; /** * ImageConvertInvocation * @description Converts an image to a different mode. @@ -1928,6 +1999,24 @@ export type components = { */ thumbnail_url: string; }; + /** ImportModelResponse */ + ImportModelResponse: { + /** + * Name + * @description The name of the imported model + */ + name: string; + /** + * Info + * @description The model info + */ + info: components["schemas"]["AddModelResult"]; + /** + * Status + * @description The status of the API response + */ + status: string; + }; /** * InfillColorInvocation * @description Infills transparent areas of an image with a solid color @@ -2440,6 +2529,64 @@ export type components = { */ strength?: number; }; + /** + * LeresImageProcessorInvocation + * @description Applies leres processing to image + */ + LeresImageProcessorInvocation: { + /** + * Id + * @description The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this node is an intermediate node. + * @default false + */ + is_intermediate?: boolean; + /** + * Type + * @default leres_image_processor + * @enum {string} + */ + type?: "leres_image_processor"; + /** + * Image + * @description The image to process + */ + image?: components["schemas"]["ImageField"]; + /** + * Thr A + * @description Leres parameter `thr_a` + * @default 0 + */ + thr_a?: number; + /** + * Thr B + * @description Leres parameter `thr_b` + * @default 0 + */ + thr_b?: number; + /** + * Boost + * @description Whether to use boost mode + * @default false + */ + boost?: boolean; + /** + * Detect Resolution + * @description The pixel resolution for detection + * @default 512 + */ + detect_resolution?: number; + /** + * Image Resolution + * @description The pixel resolution for the output image + * @default 512 + */ + image_resolution?: number; + }; /** * LineartAnimeImageProcessorInvocation * @description Applies line art anime processing to image @@ -2543,6 +2690,19 @@ export type components = { model_format: components["schemas"]["LoRAModelFormat"]; error?: components["schemas"]["ModelError"]; }; + /** + * LoRAModelField + * @description LoRA model field + */ + LoRAModelField: { + /** + * Model Name + * @description Name of the LoRA model + */ + model_name: string; + /** @description Base model */ + base_model: components["schemas"]["BaseModelType"]; + }; /** * LoRAModelFormat * @description An enumeration. @@ -2619,10 +2779,10 @@ export type components = { */ type?: "lora_loader"; /** - * Lora Name + * Lora * @description Lora model name */ - lora_name: string; + lora?: components["schemas"]["LoRAModelField"]; /** * Weight * @description With what weight to apply lora @@ -2662,6 +2822,47 @@ export type components = { */ clip?: components["schemas"]["ClipField"]; }; + /** + * MainModelField + * @description Main model field + */ + MainModelField: { + /** + * Model Name + * @description Name of the model + */ + model_name: string; + /** @description Base model */ + base_model: components["schemas"]["BaseModelType"]; + }; + /** + * MainModelLoaderInvocation + * @description Loads a main model, outputting its submodels. + */ + MainModelLoaderInvocation: { + /** + * Id + * @description The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this node is an intermediate node. + * @default false + */ + is_intermediate?: boolean; + /** + * Type + * @default main_model_loader + * @enum {string} + */ + type?: "main_model_loader"; + /** + * Model + * @description The model to load + */ + model: components["schemas"]["MainModelField"]; + }; /** * MaskFromAlphaInvocation * @description Extracts the alpha channel of an image as a mask. @@ -2855,6 +3056,16 @@ export type components = { */ thr_d?: number; }; + /** ModelConfigBase */ + ModelConfigBase: { + /** Path */ + path: string; + /** Description */ + description?: string; + /** Model Format */ + model_format?: string; + error?: components["schemas"]["ModelError"]; + }; /** * ModelError * @description An enumeration. @@ -2907,7 +3118,7 @@ export type components = { * @description An enumeration. * @enum {string} */ - ModelType: "pipeline" | "vae" | "lora" | "controlnet" | "embedding"; + ModelType: "main" | "vae" | "lora" | "controlnet" | "embedding"; /** * ModelVariantType * @description An enumeration. @@ -2917,7 +3128,7 @@ export type components = { /** ModelsList */ ModelsList: { /** Models */ - models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[]; + models: (components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[]; }; /** * MultiplyInvocation @@ -2993,12 +3204,6 @@ export type components = { * @default 512 */ height?: number; - /** - * Perlin - * @description The amount of perlin noise to add to the noise - * @default 0 - */ - perlin?: number; /** * Use Cpu * @description Use CPU for noise generation (for reproducible results across platforms) @@ -3312,47 +3517,6 @@ export type components = { */ scribble?: boolean; }; - /** - * PipelineModelField - * @description Pipeline model field - */ - PipelineModelField: { - /** - * Model Name - * @description Name of the model - */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["BaseModelType"]; - }; - /** - * PipelineModelLoaderInvocation - * @description Loads a pipeline model, outputting its submodels. - */ - PipelineModelLoaderInvocation: { - /** - * Id - * @description The id of this node. Must be unique among all nodes. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this node is an intermediate node. - * @default false - */ - is_intermediate?: boolean; - /** - * Type - * @default pipeline_model_loader - * @enum {string} - */ - type?: "pipeline_model_loader"; - /** - * Model - * @description The model to load - */ - model: components["schemas"]["PipelineModelField"]; - }; /** * PromptCollectionOutput * @description Base class for invocations that output a collection of prompts @@ -3697,11 +3861,33 @@ export type components = { antialias?: boolean; }; /** - * SchedulerPredictionType - * @description An enumeration. - * @enum {string} + * SegmentAnythingProcessorInvocation + * @description Applies segment anything processing to image */ - SchedulerPredictionType: "epsilon" | "v_prediction" | "sample"; + SegmentAnythingProcessorInvocation: { + /** + * Id + * @description The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this node is an intermediate node. + * @default false + */ + is_intermediate?: boolean; + /** + * Type + * @default segment_anything_processor + * @enum {string} + */ + type?: "segment_anything_processor"; + /** + * Image + * @description The image to process + */ + image?: components["schemas"]["ImageField"]; + }; /** * ShowImageInvocation * @description Displays a provided image, and passes it forward in the pipeline. @@ -3739,7 +3925,7 @@ export type components = { * Type * @enum {string} */ - type: "pipeline"; + type: "main"; /** Path */ path: string; /** Description */ @@ -3753,7 +3939,7 @@ export type components = { /** Vae */ vae?: string; /** Config */ - config?: string; + config: string; variant: components["schemas"]["ModelVariantType"]; }; /** StableDiffusion1ModelDiffusersConfig */ @@ -3765,7 +3951,7 @@ export type components = { * Type * @enum {string} */ - type: "pipeline"; + type: "main"; /** Path */ path: string; /** Description */ @@ -3789,7 +3975,7 @@ export type components = { * Type * @enum {string} */ - type: "pipeline"; + type: "main"; /** Path */ path: string; /** Description */ @@ -3803,11 +3989,8 @@ export type components = { /** Vae */ vae?: string; /** Config */ - config?: string; + config: string; variant: components["schemas"]["ModelVariantType"]; - prediction_type: components["schemas"]["SchedulerPredictionType"]; - /** Upcast Attention */ - upcast_attention: boolean; }; /** StableDiffusion2ModelDiffusersConfig */ StableDiffusion2ModelDiffusersConfig: { @@ -3818,7 +4001,7 @@ export type components = { * Type * @enum {string} */ - type: "pipeline"; + type: "main"; /** Path */ path: string; /** Description */ @@ -3832,9 +4015,6 @@ export type components = { /** Vae */ vae?: string; variant: components["schemas"]["ModelVariantType"]; - prediction_type: components["schemas"]["SchedulerPredictionType"]; - /** Upcast Attention */ - upcast_attention: boolean; }; /** * StepParamEasingInvocation @@ -4044,6 +4224,40 @@ export type components = { model_format: null; error?: components["schemas"]["ModelError"]; }; + /** + * TileResamplerProcessorInvocation + * @description Base class for invocations that preprocess images for ControlNet + */ + TileResamplerProcessorInvocation: { + /** + * Id + * @description The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this node is an intermediate node. + * @default false + */ + is_intermediate?: boolean; + /** + * Type + * @default tile_image_processor + * @enum {string} + */ + type?: "tile_image_processor"; + /** + * Image + * @description The image to process + */ + image?: components["schemas"]["ImageField"]; + /** + * Down Sampling Rate + * @description Down sampling rate + * @default 1 + */ + down_sampling_rate?: number; + }; /** UNetField */ UNetField: { /** @@ -4103,6 +4317,19 @@ export type components = { */ level?: 2 | 4; }; + /** + * VAEModelField + * @description Vae model field + */ + VAEModelField: { + /** + * Model Name + * @description Name of the model + */ + model_name: string; + /** @description Base model */ + base_model: components["schemas"]["BaseModelType"]; + }; /** VaeField */ VaeField: { /** @@ -4111,6 +4338,51 @@ export type components = { */ vae: components["schemas"]["ModelInfo"]; }; + /** + * VaeLoaderInvocation + * @description Loads a VAE model, outputting a VaeLoaderOutput + */ + VaeLoaderInvocation: { + /** + * Id + * @description The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this node is an intermediate node. + * @default false + */ + is_intermediate?: boolean; + /** + * Type + * @default vae_loader + * @enum {string} + */ + type?: "vae_loader"; + /** + * Vae Model + * @description The VAE to load + */ + vae_model: components["schemas"]["VAEModelField"]; + }; + /** + * VaeLoaderOutput + * @description Model loader output + */ + VaeLoaderOutput: { + /** + * Type + * @default vae_loader_output + * @enum {string} + */ + type?: "vae_loader_output"; + /** + * Vae + * @description Vae model + */ + vae?: components["schemas"]["VaeField"]; + }; /** VaeModelConfig */ VaeModelConfig: { /** Name */ @@ -4189,18 +4461,18 @@ export type components = { */ image?: components["schemas"]["ImageField"]; }; - /** - * StableDiffusion2ModelFormat - * @description An enumeration. - * @enum {string} - */ - StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; /** * StableDiffusion1ModelFormat * @description An enumeration. * @enum {string} */ StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; + /** + * StableDiffusion2ModelFormat + * @description An enumeration. + * @enum {string} + */ + StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; }; responses: never; parameters: never; @@ -4311,7 +4583,7 @@ export type operations = { }; requestBody: { content: { - "application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]; + "application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]; }; }; responses: { @@ -4348,7 +4620,7 @@ export type operations = { }; requestBody: { content: { - "application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]; + "application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]; }; }; responses: { @@ -4592,6 +4864,36 @@ export type operations = { }; }; }; + /** + * Import Model + * @description Add a model using its local path, repo_id, or remote URL + */ + import_model: { + parameters: { + query: { + /** @description A model path, repo_id or URL to import */ + name: string; + /** @description Prediction type for SDv2 checkpoint files */ + prediction_type?: "v_prediction" | "epsilon" | "sample"; + }; + }; + responses: { + /** @description The model imported successfully */ + 201: { + content: { + "application/json": components["schemas"]["ImportModelResponse"]; + }; + }; + /** @description The model could not be found */ + 404: never; + /** @description Validation Error */ + 422: { + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; /** * Delete Model * @description Delete Model @@ -4947,6 +5249,10 @@ export type operations = { */ delete_board: { parameters: { + query?: { + /** @description Permanently delete all images on the board */ + include_images?: boolean; + }; path: { /** @description The id of board to delete */ board_id: string; diff --git a/invokeai/frontend/web/src/services/api/thunks/image.ts b/invokeai/frontend/web/src/services/api/thunks/image.ts index a8b3dec5a7..d6e502bc54 100644 --- a/invokeai/frontend/web/src/services/api/thunks/image.ts +++ b/invokeai/frontend/web/src/services/api/thunks/image.ts @@ -1,6 +1,6 @@ import queryString from 'query-string'; import { createAppAsyncThunk } from 'app/store/storeUtils'; -import { selectImagesAll } from 'features/gallery/store/imagesSlice'; +import { selectImagesAll } from 'features/gallery/store/gallerySlice'; import { size } from 'lodash-es'; import { paths } from 'services/api/schema'; import { $client } from 'services/api/client'; @@ -112,6 +112,10 @@ type UploadedToastAction = { type: 'TOAST_UPLOADED'; }; +type AddToBatchAction = { + type: 'ADD_TO_BATCH'; +}; + export type PostUploadAction = | ControlNetAction | InitialImageAction @@ -119,12 +123,12 @@ export type PostUploadAction = | CanvasInitialImageAction | CanvasMergedAction | CanvasSavedToGalleryAction - | UploadedToastAction; + | UploadedToastAction + | AddToBatchAction; type UploadImageArg = paths['/api/v1/images/']['post']['parameters']['query'] & { file: File; - // file: paths['/api/v1/images/']['post']['requestBody']['content']['multipart/form-data']['file']; postUploadAction?: PostUploadAction; }; @@ -284,8 +288,7 @@ export const receivedPageOfImages = createAppAsyncThunk< const { get } = $client.get(); const state = getState(); - const { categories } = state.images; - const { selectedBoardId } = state.boards; + const { categories, selectedBoardId } = state.gallery; const images = selectImagesAll(state).filter((i) => { const isInCategory = categories.includes(i.image_category); diff --git a/invokeai/frontend/web/src/services/api/types.d.ts b/invokeai/frontend/web/src/services/api/types.d.ts index 2a2f90f434..3a0bdb71a7 100644 --- a/invokeai/frontend/web/src/services/api/types.d.ts +++ b/invokeai/frontend/web/src/services/api/types.d.ts @@ -4,89 +4,156 @@ import { components } from './schema'; type schemas = components['schemas']; /** - * Extracts the schema type from the schema. + * Marks the `type` property as required. Use for nodes. */ -type S = components['schemas'][T]; - -/** - * Extracts the node type from the schema. - * Also flags the `type` property as required. - */ -type N = O.Required< - components['schemas'][T], - 'type' ->; +type TypeReq = O.Required; // Images -export type ImageDTO = S<'ImageDTO'>; -export type BoardDTO = S<'BoardDTO'>; -export type BoardChanges = S<'BoardChanges'>; -export type ImageChanges = S<'ImageRecordChanges'>; -export type ImageCategory = S<'ImageCategory'>; -export type ResourceOrigin = S<'ResourceOrigin'>; -export type ImageField = S<'ImageField'>; +export type ImageDTO = components['schemas']['ImageDTO']; +export type BoardDTO = components['schemas']['BoardDTO']; +export type BoardChanges = components['schemas']['BoardChanges']; +export type ImageChanges = components['schemas']['ImageRecordChanges']; +export type ImageCategory = components['schemas']['ImageCategory']; +export type ResourceOrigin = components['schemas']['ResourceOrigin']; +export type ImageField = components['schemas']['ImageField']; export type OffsetPaginatedResults_BoardDTO_ = - S<'OffsetPaginatedResults_BoardDTO_'>; + components['schemas']['OffsetPaginatedResults_BoardDTO_']; export type OffsetPaginatedResults_ImageDTO_ = - S<'OffsetPaginatedResults_ImageDTO_'>; + components['schemas']['OffsetPaginatedResults_ImageDTO_']; // Models -export type ModelType = S<'ModelType'>; -export type BaseModelType = S<'BaseModelType'>; -export type PipelineModelField = S<'PipelineModelField'>; -export type ModelsList = S<'ModelsList'>; +export type ModelType = components['schemas']['ModelType']; +export type BaseModelType = components['schemas']['BaseModelType']; +export type MainModelField = components['schemas']['MainModelField']; +export type VAEModelField = components['schemas']['VAEModelField']; +export type LoRAModelField = components['schemas']['LoRAModelField']; +export type ModelsList = components['schemas']['ModelsList']; + +// Model Configs +export type LoRAModelConfig = components['schemas']['LoRAModelConfig']; +export type VaeModelConfig = components['schemas']['VaeModelConfig']; +export type ControlNetModelConfig = + components['schemas']['ControlNetModelConfig']; +export type TextualInversionModelConfig = + components['schemas']['TextualInversionModelConfig']; +export type MainModelConfig = + | components['schemas']['StableDiffusion1ModelCheckpointConfig'] + | components['schemas']['StableDiffusion1ModelDiffusersConfig'] + | components['schemas']['StableDiffusion2ModelCheckpointConfig'] + | components['schemas']['StableDiffusion2ModelDiffusersConfig']; +export type AnyModelConfig = + | LoRAModelConfig + | VaeModelConfig + | ControlNetModelConfig + | TextualInversionModelConfig + | MainModelConfig; // Graphs -export type Graph = S<'Graph'>; -export type Edge = S<'Edge'>; -export type GraphExecutionState = S<'GraphExecutionState'>; +export type Graph = components['schemas']['Graph']; +export type Edge = components['schemas']['Edge']; +export type GraphExecutionState = components['schemas']['GraphExecutionState']; // General nodes -export type CollectInvocation = N<'CollectInvocation'>; -export type IterateInvocation = N<'IterateInvocation'>; -export type RangeInvocation = N<'RangeInvocation'>; -export type RandomRangeInvocation = N<'RandomRangeInvocation'>; -export type RangeOfSizeInvocation = N<'RangeOfSizeInvocation'>; -export type InpaintInvocation = N<'InpaintInvocation'>; -export type ImageResizeInvocation = N<'ImageResizeInvocation'>; -export type RandomIntInvocation = N<'RandomIntInvocation'>; -export type CompelInvocation = N<'CompelInvocation'>; -export type DynamicPromptInvocation = N<'DynamicPromptInvocation'>; -export type NoiseInvocation = N<'NoiseInvocation'>; -export type TextToLatentsInvocation = N<'TextToLatentsInvocation'>; -export type LatentsToLatentsInvocation = N<'LatentsToLatentsInvocation'>; -export type ImageToLatentsInvocation = N<'ImageToLatentsInvocation'>; -export type LatentsToImageInvocation = N<'LatentsToImageInvocation'>; -export type PipelineModelLoaderInvocation = N<'PipelineModelLoaderInvocation'>; +export type CollectInvocation = TypeReq< + components['schemas']['CollectInvocation'] +>; +export type IterateInvocation = TypeReq< + components['schemas']['IterateInvocation'] +>; +export type RangeInvocation = TypeReq; +export type RandomRangeInvocation = TypeReq< + components['schemas']['RandomRangeInvocation'] +>; +export type RangeOfSizeInvocation = TypeReq< + components['schemas']['RangeOfSizeInvocation'] +>; +export type InpaintInvocation = TypeReq< + components['schemas']['InpaintInvocation'] +>; +export type ImageResizeInvocation = TypeReq< + components['schemas']['ImageResizeInvocation'] +>; +export type RandomIntInvocation = TypeReq< + components['schemas']['RandomIntInvocation'] +>; +export type CompelInvocation = TypeReq< + components['schemas']['CompelInvocation'] +>; +export type DynamicPromptInvocation = TypeReq< + components['schemas']['DynamicPromptInvocation'] +>; +export type NoiseInvocation = TypeReq; +export type TextToLatentsInvocation = TypeReq< + components['schemas']['TextToLatentsInvocation'] +>; +export type LatentsToLatentsInvocation = TypeReq< + components['schemas']['LatentsToLatentsInvocation'] +>; +export type ImageToLatentsInvocation = TypeReq< + components['schemas']['ImageToLatentsInvocation'] +>; +export type LatentsToImageInvocation = TypeReq< + components['schemas']['LatentsToImageInvocation'] +>; +export type ImageCollectionInvocation = TypeReq< + components['schemas']['ImageCollectionInvocation'] +>; +export type MainModelLoaderInvocation = TypeReq< + components['schemas']['MainModelLoaderInvocation'] +>; +export type LoraLoaderInvocation = TypeReq< + components['schemas']['LoraLoaderInvocation'] +>; // ControlNet Nodes -export type ControlNetInvocation = N<'ControlNetInvocation'>; -export type CannyImageProcessorInvocation = N<'CannyImageProcessorInvocation'>; -export type ContentShuffleImageProcessorInvocation = - N<'ContentShuffleImageProcessorInvocation'>; -export type HedImageProcessorInvocation = N<'HedImageProcessorInvocation'>; -export type LineartAnimeImageProcessorInvocation = - N<'LineartAnimeImageProcessorInvocation'>; -export type LineartImageProcessorInvocation = - N<'LineartImageProcessorInvocation'>; -export type MediapipeFaceProcessorInvocation = - N<'MediapipeFaceProcessorInvocation'>; -export type MidasDepthImageProcessorInvocation = - N<'MidasDepthImageProcessorInvocation'>; -export type MlsdImageProcessorInvocation = N<'MlsdImageProcessorInvocation'>; -export type NormalbaeImageProcessorInvocation = - N<'NormalbaeImageProcessorInvocation'>; -export type OpenposeImageProcessorInvocation = - N<'OpenposeImageProcessorInvocation'>; -export type PidiImageProcessorInvocation = N<'PidiImageProcessorInvocation'>; -export type ZoeDepthImageProcessorInvocation = - N<'ZoeDepthImageProcessorInvocation'>; +export type ControlNetInvocation = TypeReq< + components['schemas']['ControlNetInvocation'] +>; +export type CannyImageProcessorInvocation = TypeReq< + components['schemas']['CannyImageProcessorInvocation'] +>; +export type ContentShuffleImageProcessorInvocation = TypeReq< + components['schemas']['ContentShuffleImageProcessorInvocation'] +>; +export type HedImageProcessorInvocation = TypeReq< + components['schemas']['HedImageProcessorInvocation'] +>; +export type LineartAnimeImageProcessorInvocation = TypeReq< + components['schemas']['LineartAnimeImageProcessorInvocation'] +>; +export type LineartImageProcessorInvocation = TypeReq< + components['schemas']['LineartImageProcessorInvocation'] +>; +export type MediapipeFaceProcessorInvocation = TypeReq< + components['schemas']['MediapipeFaceProcessorInvocation'] +>; +export type MidasDepthImageProcessorInvocation = TypeReq< + components['schemas']['MidasDepthImageProcessorInvocation'] +>; +export type MlsdImageProcessorInvocation = TypeReq< + components['schemas']['MlsdImageProcessorInvocation'] +>; +export type NormalbaeImageProcessorInvocation = TypeReq< + components['schemas']['NormalbaeImageProcessorInvocation'] +>; +export type OpenposeImageProcessorInvocation = TypeReq< + components['schemas']['OpenposeImageProcessorInvocation'] +>; +export type PidiImageProcessorInvocation = TypeReq< + components['schemas']['PidiImageProcessorInvocation'] +>; +export type ZoeDepthImageProcessorInvocation = TypeReq< + components['schemas']['ZoeDepthImageProcessorInvocation'] +>; // Node Outputs -export type ImageOutput = S<'ImageOutput'>; -export type MaskOutput = S<'MaskOutput'>; -export type PromptOutput = S<'PromptOutput'>; -export type IterateInvocationOutput = S<'IterateInvocationOutput'>; -export type CollectInvocationOutput = S<'CollectInvocationOutput'>; -export type LatentsOutput = S<'LatentsOutput'>; -export type GraphInvocationOutput = S<'GraphInvocationOutput'>; +export type ImageOutput = components['schemas']['ImageOutput']; +export type MaskOutput = components['schemas']['MaskOutput']; +export type PromptOutput = components['schemas']['PromptOutput']; +export type IterateInvocationOutput = + components['schemas']['IterateInvocationOutput']; +export type CollectInvocationOutput = + components['schemas']['CollectInvocationOutput']; +export type LatentsOutput = components['schemas']['LatentsOutput']; +export type GraphInvocationOutput = + components['schemas']['GraphInvocationOutput']; diff --git a/invokeai/frontend/web/src/services/events/middleware.ts b/invokeai/frontend/web/src/services/events/middleware.ts index 85641b88a0..665761a626 100644 --- a/invokeai/frontend/web/src/services/events/middleware.ts +++ b/invokeai/frontend/web/src/services/events/middleware.ts @@ -1,18 +1,18 @@ import { Middleware, MiddlewareAPI } from '@reduxjs/toolkit'; -import { io, Socket } from 'socket.io-client'; +import { Socket, io } from 'socket.io-client'; +import { AppThunkDispatch, RootState } from 'app/store/store'; +import { getTimestamp } from 'common/util/getTimestamp'; +import { sessionCreated } from 'services/api/thunks/session'; import { ClientToServerEvents, ServerToClientEvents, } from 'services/events/types'; import { socketSubscribed, socketUnsubscribed } from './actions'; -import { AppThunkDispatch, RootState } from 'app/store/store'; -import { getTimestamp } from 'common/util/getTimestamp'; -import { sessionCreated } from 'services/api/thunks/session'; // import { OpenAPI } from 'services/api/types'; -import { setEventListeners } from 'services/events/util/setEventListeners'; import { log } from 'app/logging/useLogger'; import { $authToken, $baseUrl } from 'services/api/client'; +import { setEventListeners } from 'services/events/util/setEventListeners'; const socketioLog = log.child({ namespace: 'socketio' }); @@ -88,7 +88,7 @@ export const socketMiddleware = () => { socketSubscribed({ sessionId: sessionId, timestamp: getTimestamp(), - boardId: getState().boards.selectedBoardId, + boardId: getState().gallery.selectedBoardId, }) ); } diff --git a/invokeai/frontend/web/src/theme/components/button.ts b/invokeai/frontend/web/src/theme/components/button.ts index 75662f7d42..7bb8a39a71 100644 --- a/invokeai/frontend/web/src/theme/components/button.ts +++ b/invokeai/frontend/web/src/theme/components/button.ts @@ -7,10 +7,10 @@ const invokeAI = defineStyle((props) => { if (c === 'base') { const _disabled = { - bg: mode('base.200', 'base.700')(props), - color: mode('base.500', 'base.150')(props), + bg: mode('base.150', 'base.700')(props), + color: mode('base.500', 'base.500')(props), svg: { - fill: mode('base.500', 'base.150')(props), + fill: mode('base.500', 'base.500')(props), }, opacity: 1, }; @@ -30,7 +30,6 @@ const invokeAI = defineStyle((props) => { 'drop-shadow(0px 0px 0.3rem var(--invokeai-colors-base-800))' )(props), }, - _disabled, _hover: { bg: mode('base.300', 'base.500')(props), color: mode('base.900', 'base.50')(props), @@ -39,34 +38,16 @@ const invokeAI = defineStyle((props) => { }, _disabled, }, - _checked: { - bg: mode('accent.400', 'accent.600')(props), - color: mode('base.50', 'base.100')(props), - svg: { - fill: mode(`${c}.50`, `${c}.100`)(props), - filter: mode( - `drop-shadow(0px 0px 0.3rem var(--invokeai-colors-${c}-600))`, - `drop-shadow(0px 0px 0.3rem var(--invokeai-colors-${c}-800))` - )(props), - }, - _disabled, - _hover: { - bg: mode('accent.500', 'accent.500')(props), - color: mode('white', 'base.50')(props), - svg: { - fill: mode('white', 'base.50')(props), - }, - _disabled, - }, - }, + _disabled, }; } const _disabled = { - bg: mode(`${c}.200`, `${c}.700`)(props), - color: mode(`${c}.100`, `${c}.150`)(props), + bg: mode(`${c}.250`, `${c}.700`)(props), + color: mode(`${c}.50`, `${c}.500`)(props), svg: { - fill: mode(`${c}.100`, `${c}.150`)(props), + fill: mode(`${c}.50`, `${c}.500`)(props), + filter: 'unset', }, opacity: 1, filter: mode(undefined, 'saturate(65%)')(props), @@ -78,7 +59,7 @@ const invokeAI = defineStyle((props) => { borderRadius: 'base', textShadow: mode( `0 0 0.3rem var(--invokeai-colors-${c}-600)`, - `0 0 0.3rem var(--invokeai-colors-${c}-900)` + `0 0 0.3rem var(--invokeai-colors-${c}-800)` )(props), svg: { fill: mode(`base.50`, `base.100`)(props), @@ -96,26 +77,6 @@ const invokeAI = defineStyle((props) => { }, _disabled, }, - _checked: { - bg: mode('accent.400', 'accent.600')(props), - color: mode('base.50', 'base.100')(props), - svg: { - fill: mode(`base.50`, `base.100`)(props), - filter: mode( - `drop-shadow(0px 0px 0.3rem var(--invokeai-colors-${c}-600))`, - `drop-shadow(0px 0px 0.3rem var(--invokeai-colors-${c}-800))` - )(props), - }, - _disabled, - _hover: { - bg: mode('accent.500', 'accent.500')(props), - color: mode('white', 'base.50')(props), - svg: { - fill: mode('white', 'base.50')(props), - }, - _disabled, - }, - }, }; }); diff --git a/invokeai/frontend/web/src/theme/components/menu.ts b/invokeai/frontend/web/src/theme/components/menu.ts index 02f75087ed..324720a040 100644 --- a/invokeai/frontend/web/src/theme/components/menu.ts +++ b/invokeai/frontend/web/src/theme/components/menu.ts @@ -22,6 +22,8 @@ const invokeAI = definePartsStyle((props) => ({ list: { zIndex: 9999, bg: mode('base.200', 'base.800')(props), + shadow: 'dark-lg', + border: 'none', }, item: { // this will style the MenuItem and MenuItemOption components diff --git a/invokeai/frontend/web/src/theme/components/progress.ts b/invokeai/frontend/web/src/theme/components/progress.ts index 87b6b7af01..71231869ce 100644 --- a/invokeai/frontend/web/src/theme/components/progress.ts +++ b/invokeai/frontend/web/src/theme/components/progress.ts @@ -3,24 +3,19 @@ import { createMultiStyleConfigHelpers, defineStyle, } from '@chakra-ui/styled-system'; +import { mode } from '@chakra-ui/theme-tools'; const { defineMultiStyleConfig, definePartsStyle } = createMultiStyleConfigHelpers(parts.keys); const invokeAIFilledTrack = defineStyle((_props) => ({ - bg: 'accent.600', - // TODO: the animation is nice but looks weird bc it is substantially longer than each step - // so we get to 100% long before it finishes - // transition: 'width 0.2s ease-in-out', - _indeterminate: { - bgGradient: - 'linear(to-r, transparent 0%, accent.600 50%, transparent 100%);', - }, + bg: 'accentAlpha.500', })); const invokeAITrack = defineStyle((_props) => { + const { colorScheme: c } = _props; return { - bg: 'none', + bg: mode(`${c}.200`, `${c}.700`)(_props), }; }); diff --git a/invokeai/frontend/web/src/theme/components/skeleton.ts b/invokeai/frontend/web/src/theme/components/skeleton.ts new file mode 100644 index 0000000000..8ee97e0fb8 --- /dev/null +++ b/invokeai/frontend/web/src/theme/components/skeleton.ts @@ -0,0 +1,25 @@ +import { defineStyle, defineStyleConfig, cssVar } from '@chakra-ui/react'; + +const $startColor = cssVar('skeleton-start-color'); +const $endColor = cssVar('skeleton-end-color'); + +const invokeAI = defineStyle({ + borderRadius: 'base', + maxW: 'full', + maxH: 'full', + _light: { + [$startColor.variable]: 'colors.base.250', + [$endColor.variable]: 'colors.base.450', + }, + _dark: { + [$startColor.variable]: 'colors.base.700', + [$endColor.variable]: 'colors.base.500', + }, +}); + +export const skeletonTheme = defineStyleConfig({ + variants: { invokeAI }, + defaultProps: { + variant: 'invokeAI', + }, +}); diff --git a/invokeai/frontend/web/src/theme/components/textarea.ts b/invokeai/frontend/web/src/theme/components/textarea.ts index 85e6e37d3f..b737cf5e57 100644 --- a/invokeai/frontend/web/src/theme/components/textarea.ts +++ b/invokeai/frontend/web/src/theme/components/textarea.ts @@ -1,7 +1,28 @@ import { defineStyle, defineStyleConfig } from '@chakra-ui/react'; import { getInputOutlineStyles } from '../util/getInputOutlineStyles'; -const invokeAI = defineStyle((props) => getInputOutlineStyles(props)); +const invokeAI = defineStyle((props) => ({ + ...getInputOutlineStyles(props), + '::-webkit-scrollbar': { + display: 'initial', + }, + '::-webkit-resizer': { + backgroundImage: `linear-gradient(135deg, + var(--invokeai-colors-base-50) 0%, + var(--invokeai-colors-base-50) 70%, + var(--invokeai-colors-base-200) 70%, + var(--invokeai-colors-base-200) 100%)`, + }, + _dark: { + '::-webkit-resizer': { + backgroundImage: `linear-gradient(135deg, + var(--invokeai-colors-base-900) 0%, + var(--invokeai-colors-base-900) 70%, + var(--invokeai-colors-base-800) 70%, + var(--invokeai-colors-base-800) 100%)`, + }, + }, +})); export const textareaTheme = defineStyleConfig({ variants: { diff --git a/invokeai/frontend/web/src/theme/theme.ts b/invokeai/frontend/web/src/theme/theme.ts index 76b4aaaacc..03d1f640ac 100644 --- a/invokeai/frontend/web/src/theme/theme.ts +++ b/invokeai/frontend/web/src/theme/theme.ts @@ -19,6 +19,7 @@ import { tabsTheme } from './components/tabs'; import { textTheme } from './components/text'; import { textareaTheme } from './components/textarea'; import { tooltipTheme } from './components/tooltip'; +import { skeletonTheme } from './components/skeleton'; export const theme: ThemeOverride = { config: { @@ -68,6 +69,11 @@ export const theme: ThemeOverride = { working: `0 0 7px var(--invokeai-colors-working-400)`, error: `0 0 7px var(--invokeai-colors-error-400)`, }, + selected: { + light: + '0px 0px 0px 1px var(--invokeai-colors-base-150), 0px 0px 0px 4px var(--invokeai-colors-accent-400)', + dark: '0px 0px 0px 1px var(--invokeai-colors-base-900), 0px 0px 0px 4px var(--invokeai-colors-accent-400)', + }, nodeSelectedOutline: `0 0 0 2px var(--invokeai-colors-base-500)`, }, colors: InvokeAIColors, @@ -82,6 +88,7 @@ export const theme: ThemeOverride = { Switch: switchTheme, NumberInput: numberInputTheme, Select: selectTheme, + Skeleton: skeletonTheme, Slider: sliderTheme, Popover: popoverTheme, Modal: modalTheme, diff --git a/invokeai/frontend/web/stats.html b/invokeai/frontend/web/stats.html index dc999e13df..7c7df1671a 100644 --- a/invokeai/frontend/web/stats.html +++ b/invokeai/frontend/web/stats.html @@ -145,9 +145,9 @@ main { var drawChart = (function (exports) { 'use strict'; - var n,l$1,u$1,t$1,o$2,r$1,f$1={},e$1=[],c$1=/acit|ex(?:s|g|n|p|$)|rph|grid|ows|mnc|ntw|ine[ch]|zoo|^ord|itera/i;function s$1(n,l){for(var u in l)n[u]=l[u];return n}function a$1(n){var l=n.parentNode;l&&l.removeChild(n);}function h$1(l,u,i){var t,o,r,f={};for(r in u)"key"==r?t=u[r]:"ref"==r?o=u[r]:f[r]=u[r];if(arguments.length>2&&(f.children=arguments.length>3?n.call(arguments,2):i),"function"==typeof l&&null!=l.defaultProps)for(r in l.defaultProps)void 0===f[r]&&(f[r]=l.defaultProps[r]);return v$1(l,f,t,o,null)}function v$1(n,i,t,o,r){var f={type:n,props:i,key:t,ref:o,__k:null,__:null,__b:0,__e:null,__d:void 0,__c:null,__h:null,constructor:void 0,__v:null==r?++u$1:r};return null==r&&null!=l$1.vnode&&l$1.vnode(f),f}function p$1(n){return n.children}function d$1(n,l){this.props=n,this.context=l;}function _$2(n,l){if(null==l)return n.__?_$2(n.__,n.__.__k.indexOf(n)+1):null;for(var u;l0?v$1(k.type,k.props,k.key,k.ref?k.ref:null,k.__v):k)){if(k.__=u,k.__b=u.__b+1,null===(d=x[h])||d&&k.key==d.key&&k.type===d.type)x[h]=void 0;else for(y=0;y2&&(f.children=arguments.length>3?n.call(arguments,2):i),"function"==typeof l&&null!=l.defaultProps)for(r in l.defaultProps)void 0===f[r]&&(f[r]=l.defaultProps[r]);return d$1(l,f,t,o,null)}function d$1(n,i,t,o,r){var f={type:n,props:i,key:t,ref:o,__k:null,__:null,__b:0,__e:null,__d:void 0,__c:null,__h:null,constructor:void 0,__v:null==r?++u$1:r};return null==r&&null!=l$1.vnode&&l$1.vnode(f),f}function k$1(n){return n.children}function b$1(n,l){this.props=n,this.context=l;}function g$1(n,l){if(null==l)return n.__?g$1(n.__,n.__.__k.indexOf(n)+1):null;for(var u;ll&&t$1.sort(f$1));x.__r=0;}function P(n,l,u,i,t,o,r,f,e,a){var h,p,y,_,b,m,w,x=i&&i.__k||s$1,P=x.length;for(u.__k=[],h=0;h0?d$1(_.type,_.props,_.key,_.ref?_.ref:null,_.__v):_)){if(_.__=u,_.__b=u.__b+1,null===(y=x[h])||y&&_.key==y.key&&_.type===y.type)x[h]=void 0;else for(p=0;p=0;l--)if((u=n.__k[l])&&(i=A(u)))return i;return null}function H(n,l,u,i,t){var o;for(o in u)"children"===o||"key"===o||o in l||T$1(n,o,null,u[o],i);for(o in l)t&&"function"!=typeof l[o]||"children"===o||"key"===o||"value"===o||"checked"===o||u[o]===l[o]||T$1(n,o,l[o],u[o],i);}function I(n,l,u){"-"===l[0]?n.setProperty(l,null==u?"":u):n[l]=null==u?"":"number"!=typeof u||a$1.test(l)?u:u+"px";}function T$1(n,l,u,i,t){var o;n:if("style"===l)if("string"==typeof u)n.style.cssText=u;else {if("string"==typeof i&&(n.style.cssText=i=""),i)for(l in i)u&&l in u||I(n.style,l,"");if(u)for(l in u)i&&u[l]===i[l]||I(n.style,l,u[l]);}else if("o"===l[0]&&"n"===l[1])o=l!==(l=l.replace(/Capture$/,"")),l=l.toLowerCase()in n?l.toLowerCase().slice(2):l.slice(2),n.l||(n.l={}),n.l[l+o]=u,u?i||n.addEventListener(l,o?z$1:j$1,o):n.removeEventListener(l,o?z$1:j$1,o);else if("dangerouslySetInnerHTML"!==l){if(t)l=l.replace(/xlink(H|:h)/,"h").replace(/sName$/,"s");else if("width"!==l&&"height"!==l&&"href"!==l&&"list"!==l&&"form"!==l&&"tabIndex"!==l&&"download"!==l&&"rowSpan"!==l&&"colSpan"!==l&&l in n)try{n[l]=null==u?"":u;break n}catch(n){}"function"==typeof u||(null==u||!1===u&&"-"!==l[4]?n.removeAttribute(l):n.setAttribute(l,u));}}function j$1(n){return this.l[n.type+!1](l$1.event?l$1.event(n):n)}function z$1(n){return this.l[n.type+!0](l$1.event?l$1.event(n):n)}function L(n,u,i,t,o,r,f,e,c){var s,a,p,y,d,_,g,m,w,x,C,S,$,A,H,I=u.type;if(void 0!==u.constructor)return null;null!=i.__h&&(c=i.__h,e=u.__e=i.__e,u.__h=null,r=[e]),(s=l$1.__b)&&s(u);try{n:if("function"==typeof I){if(m=u.props,w=(s=I.contextType)&&t[s.__c],x=s?w?w.props.value:s.__:t,i.__c?g=(a=u.__c=i.__c).__=a.__E:("prototype"in I&&I.prototype.render?u.__c=a=new I(m,x):(u.__c=a=new b$1(m,x),a.constructor=I,a.render=B$1),w&&w.sub(a),a.props=m,a.state||(a.state={}),a.context=x,a.__n=t,p=a.__d=!0,a.__h=[],a._sb=[]),null==a.__s&&(a.__s=a.state),null!=I.getDerivedStateFromProps&&(a.__s==a.state&&(a.__s=h$1({},a.__s)),h$1(a.__s,I.getDerivedStateFromProps(m,a.__s))),y=a.props,d=a.state,a.__v=u,p)null==I.getDerivedStateFromProps&&null!=a.componentWillMount&&a.componentWillMount(),null!=a.componentDidMount&&a.__h.push(a.componentDidMount);else {if(null==I.getDerivedStateFromProps&&m!==y&&null!=a.componentWillReceiveProps&&a.componentWillReceiveProps(m,x),!a.__e&&null!=a.shouldComponentUpdate&&!1===a.shouldComponentUpdate(m,a.__s,x)||u.__v===i.__v){for(u.__v!==i.__v&&(a.props=m,a.state=a.__s,a.__d=!1),a.__e=!1,u.__e=i.__e,u.__k=i.__k,u.__k.forEach(function(n){n&&(n.__=u);}),C=0;C=i.__.length&&i.__.push({__V:c}),i.__[t]}function p(n){return o=1,y(B$1,n)}function y(n,u,i){var o=d(t++,2);if(o.t=n,!o.__c&&(o.__=[i?i(u):B$1(void 0,u),function(n){var t=o.__N?o.__N[0]:o.__[0],r=o.t(t,n);t!==r&&(o.__N=[r,o.__[1]],o.__c.setState({}));}],o.__c=r,!r.u)){r.u=!0;var f=r.shouldComponentUpdate;r.shouldComponentUpdate=function(n,t,r){if(!o.__c.__H)return !0;var u=o.__c.__H.__.filter(function(n){return n.__c});if(u.every(function(n){return !n.__N}))return !f||f.call(this,n,t,r);var i=!1;return u.forEach(function(n){if(n.__N){var t=n.__[0];n.__=n.__N,n.__N=void 0,t!==n.__[0]&&(i=!0);}}),!(!i&&o.__c.props===n)&&(!f||f.call(this,n,t,r))};}return o.__N||o.__}function h(u,i){var o=d(t++,3);!l$1.__s&&z(o.__H,i)&&(o.__=u,o.i=i,r.__H.__h.push(o));}function s(u,i){var o=d(t++,4);!l$1.__s&&z(o.__H,i)&&(o.__=u,o.i=i,r.__h.push(o));}function _(n){return o=5,F(function(){return {current:n}},[])}function F(n,r){var u=d(t++,7);return z(u.__H,r)?(u.__V=n(),u.i=r,u.__h=n,u.__V):u.__}function T(n,t){return o=8,F(function(){return n},t)}function q(n){var u=r.context[n.__c],i=d(t++,9);return i.c=n,u?(null==i.__&&(i.__=!0,u.sub(r)),u.props.value):n.__}function b(){for(var t;t=f.shift();)if(t.__P&&t.__H)try{t.__H.__h.forEach(k),t.__H.__h.forEach(w),t.__H.__h=[];}catch(r){t.__H.__h=[],l$1.__e(r,t.__v);}}l$1.__b=function(n){r=null,e&&e(n);},l$1.__r=function(n){a&&a(n),t=0;var i=(r=n.__c).__H;i&&(u===r?(i.__h=[],r.__h=[],i.__.forEach(function(n){n.__N&&(n.__=n.__N),n.__V=c,n.__N=n.i=void 0;})):(i.__h.forEach(k),i.__h.forEach(w),i.__h=[])),u=r;},l$1.diffed=function(t){v&&v(t);var o=t.__c;o&&o.__H&&(o.__H.__h.length&&(1!==f.push(o)&&i===l$1.requestAnimationFrame||((i=l$1.requestAnimationFrame)||j)(b)),o.__H.__.forEach(function(n){n.i&&(n.__H=n.i),n.__V!==c&&(n.__=n.__V),n.i=void 0,n.__V=c;})),u=r=null;},l$1.__c=function(t,r){r.some(function(t){try{t.__h.forEach(k),t.__h=t.__h.filter(function(n){return !n.__||w(n)});}catch(u){r.some(function(n){n.__h&&(n.__h=[]);}),r=[],l$1.__e(u,t.__v);}}),l&&l(t,r);},l$1.unmount=function(t){m&&m(t);var r,u=t.__c;u&&u.__H&&(u.__H.__.forEach(function(n){try{k(n);}catch(n){r=n;}}),u.__H=void 0,r&&l$1.__e(r,u.__v));};var g="function"==typeof requestAnimationFrame;function j(n){var t,r=function(){clearTimeout(u),g&&cancelAnimationFrame(t),setTimeout(n);},u=setTimeout(r,100);g&&(t=requestAnimationFrame(r));}function k(n){var t=r,u=n.__c;"function"==typeof u&&(n.__c=void 0,u()),r=t;}function w(n){var t=r;n.__c=n.__(),r=t;}function z(n,t){return !n||n.length!==t.length||t.some(function(t,r){return t!==n[r]})}function B$1(n,t){return "function"==typeof t?t(n):t} + var t,r,u,i,o=0,f=[],c=[],e=l$1.__b,a=l$1.__r,v=l$1.diffed,l=l$1.__c,m=l$1.unmount;function d(t,u){l$1.__h&&l$1.__h(r,t,o||u),o=0;var i=r.__H||(r.__H={__:[],__h:[]});return t>=i.__.length&&i.__.push({__V:c}),i.__[t]}function h(n){return o=1,s(B,n)}function s(n,u,i){var o=d(t++,2);if(o.t=n,!o.__c&&(o.__=[i?i(u):B(void 0,u),function(n){var t=o.__N?o.__N[0]:o.__[0],r=o.t(t,n);t!==r&&(o.__N=[r,o.__[1]],o.__c.setState({}));}],o.__c=r,!r.u)){var f=function(n,t,r){if(!o.__c.__H)return !0;var u=o.__c.__H.__.filter(function(n){return n.__c});if(u.every(function(n){return !n.__N}))return !c||c.call(this,n,t,r);var i=!1;return u.forEach(function(n){if(n.__N){var t=n.__[0];n.__=n.__N,n.__N=void 0,t!==n.__[0]&&(i=!0);}}),!(!i&&o.__c.props===n)&&(!c||c.call(this,n,t,r))};r.u=!0;var c=r.shouldComponentUpdate,e=r.componentWillUpdate;r.componentWillUpdate=function(n,t,r){if(this.__e){var u=c;c=void 0,f(n,t,r),c=u;}e&&e.call(this,n,t,r);},r.shouldComponentUpdate=f;}return o.__N||o.__}function p(u,i){var o=d(t++,3);!l$1.__s&&z(o.__H,i)&&(o.__=u,o.i=i,r.__H.__h.push(o));}function y(u,i){var o=d(t++,4);!l$1.__s&&z(o.__H,i)&&(o.__=u,o.i=i,r.__h.push(o));}function _(n){return o=5,F(function(){return {current:n}},[])}function F(n,r){var u=d(t++,7);return z(u.__H,r)?(u.__V=n(),u.i=r,u.__h=n,u.__V):u.__}function T(n,t){return o=8,F(function(){return n},t)}function q(n){var u=r.context[n.__c],i=d(t++,9);return i.c=n,u?(null==i.__&&(i.__=!0,u.sub(r)),u.props.value):n.__}function b(){for(var t;t=f.shift();)if(t.__P&&t.__H)try{t.__H.__h.forEach(k),t.__H.__h.forEach(w),t.__H.__h=[];}catch(r){t.__H.__h=[],l$1.__e(r,t.__v);}}l$1.__b=function(n){r=null,e&&e(n);},l$1.__r=function(n){a&&a(n),t=0;var i=(r=n.__c).__H;i&&(u===r?(i.__h=[],r.__h=[],i.__.forEach(function(n){n.__N&&(n.__=n.__N),n.__V=c,n.__N=n.i=void 0;})):(i.__h.forEach(k),i.__h.forEach(w),i.__h=[],t=0)),u=r;},l$1.diffed=function(t){v&&v(t);var o=t.__c;o&&o.__H&&(o.__H.__h.length&&(1!==f.push(o)&&i===l$1.requestAnimationFrame||((i=l$1.requestAnimationFrame)||j)(b)),o.__H.__.forEach(function(n){n.i&&(n.__H=n.i),n.__V!==c&&(n.__=n.__V),n.i=void 0,n.__V=c;})),u=r=null;},l$1.__c=function(t,r){r.some(function(t){try{t.__h.forEach(k),t.__h=t.__h.filter(function(n){return !n.__||w(n)});}catch(u){r.some(function(n){n.__h&&(n.__h=[]);}),r=[],l$1.__e(u,t.__v);}}),l&&l(t,r);},l$1.unmount=function(t){m&&m(t);var r,u=t.__c;u&&u.__H&&(u.__H.__.forEach(function(n){try{k(n);}catch(n){r=n;}}),u.__H=void 0,r&&l$1.__e(r,u.__v));};var g="function"==typeof requestAnimationFrame;function j(n){var t,r=function(){clearTimeout(u),g&&cancelAnimationFrame(t),setTimeout(n);},u=setTimeout(r,100);g&&(t=requestAnimationFrame(r));}function k(n){var t=r,u=n.__c;"function"==typeof u&&(n.__c=void 0,u()),r=t;}function w(n){var t=r;n.__c=n.__(),r=t;}function z(n,t){return !n||n.length!==t.length||t.some(function(t,r){return t!==n[r]})}function B(n,t){return "function"==typeof t?t(n):t} const PLACEHOLDER = "bundle-*:**/file/**,**/file**, bundle-*:"; const SideBar = ({ availableSizeProperties, sizeProperty, setSizeProperty, onExcludeChange, onIncludeChange, }) => { - const [includeValue, setIncludeValue] = p(""); - const [excludeValue, setExcludeValue] = p(""); + const [includeValue, setIncludeValue] = h(""); + const [excludeValue, setExcludeValue] = h(""); const handleSizePropertyChange = (sizeProp) => () => { if (sizeProp !== sizeProperty) { setSizeProperty(sizeProp); @@ -682,23 +680,17 @@ var drawChart = (function (exports) { setExcludeValue(value); onExcludeChange(value); }; - return (o$1("aside", Object.assign({ className: "sidebar" }, { children: [o$1("div", Object.assign({ className: "size-selectors" }, { children: availableSizeProperties.length > 1 && + return (o$1("aside", { className: "sidebar", children: [o$1("div", { className: "size-selectors", children: availableSizeProperties.length > 1 && availableSizeProperties.map((sizeProp) => { const id = `selector-${sizeProp}`; - return (o$1("div", Object.assign({ className: "size-selector" }, { children: [o$1("input", { type: "radio", id: id, checked: sizeProp === sizeProperty, onChange: handleSizePropertyChange(sizeProp) }), o$1("label", Object.assign({ htmlFor: id }, { children: LABELS[sizeProp] }))] }), sizeProp)); - }) })), o$1("div", Object.assign({ className: "module-filters" }, { children: [o$1("div", Object.assign({ className: "module-filter" }, { children: [o$1("label", Object.assign({ htmlFor: "module-filter-exclude" }, { children: "Exclude" })), o$1("input", { type: "text", id: "module-filter-exclude", value: excludeValue, onInput: handleExcludeChange, placeholder: PLACEHOLDER })] })), o$1("div", Object.assign({ className: "module-filter" }, { children: [o$1("label", Object.assign({ htmlFor: "module-filter-include" }, { children: "Include" })), o$1("input", { type: "text", id: "module-filter-include", value: includeValue, onInput: handleIncludeChange, placeholder: PLACEHOLDER })] }))] }))] }))); + return (o$1("div", { className: "size-selector", children: [o$1("input", { type: "radio", id: id, checked: sizeProp === sizeProperty, onChange: handleSizePropertyChange(sizeProp) }), o$1("label", { htmlFor: id, children: LABELS[sizeProp] })] }, sizeProp)); + }) }), o$1("div", { className: "module-filters", children: [o$1("div", { className: "module-filter", children: [o$1("label", { htmlFor: "module-filter-exclude", children: "Exclude" }), o$1("input", { type: "text", id: "module-filter-exclude", value: excludeValue, onInput: handleExcludeChange, placeholder: PLACEHOLDER })] }), o$1("div", { className: "module-filter", children: [o$1("label", { htmlFor: "module-filter-include", children: "Include" }), o$1("input", { type: "text", id: "module-filter-include", value: includeValue, onInput: handleIncludeChange, placeholder: PLACEHOLDER })] })] })] })); }; function getDefaultExportFromCjs (x) { return x && x.__esModule && Object.prototype.hasOwnProperty.call(x, 'default') ? x['default'] : x; } - var picomatchBrowserExports = {}; - var picomatchBrowser = { - get exports(){ return picomatchBrowserExports; }, - set exports(v){ picomatchBrowserExports = v; }, - }; - var utils$3 = {}; const WIN_SLASH = '\\\\/'; @@ -941,7 +933,7 @@ var drawChart = (function (exports) { } else { return path.replace(/\/$/, '').replace(/.*\//, ''); } - }; + }; } (utils$3)); const utils$2 = utils$3; @@ -2738,12 +2730,9 @@ var drawChart = (function (exports) { var picomatch_1 = picomatch; - (function (module) { + var picomatchBrowser = picomatch_1; - module.exports = picomatch_1; - } (picomatchBrowser)); - - var pm = /*@__PURE__*/getDefaultExportFromCjs(picomatchBrowserExports); + var pm = /*@__PURE__*/getDefaultExportFromCjs(picomatchBrowser); function isArray(arg) { return Array.isArray(arg); @@ -2834,8 +2823,8 @@ var drawChart = (function (exports) { })); }; const useFilter = () => { - const [includeFilter, setIncludeFilter] = p(""); - const [excludeFilter, setExcludeFilter] = p(""); + const [includeFilter, setIncludeFilter] = h(""); + const [excludeFilter, setExcludeFilter] = h(""); const setIncludeFilterTrottled = F(() => throttleFilter(setIncludeFilter, 200), []); const setExcludeFilterTrottled = F(() => throttleFilter(setExcludeFilter, 200), []); const isIncluded = F(() => createFilter(prepareFilter(includeFilter), prepareFilter(excludeFilter)), [includeFilter, excludeFilter]); @@ -2924,6 +2913,7 @@ var drawChart = (function (exports) { const ascendingBisect = bisector(ascending); const bisectRight = ascendingBisect.right; bisector(number$1).center; + var bisect = bisectRight; class InternMap extends Map { constructor(entries, key = keyof) { @@ -2997,59 +2987,60 @@ var drawChart = (function (exports) { })(values, 0); } - var e10 = Math.sqrt(50), + const e10 = Math.sqrt(50), e5 = Math.sqrt(10), e2 = Math.sqrt(2); - function ticks(start, stop, count) { - var reverse, - i = -1, - n, - ticks, - step; - - stop = +stop, start = +start, count = +count; - if (start === stop && count > 0) return [start]; - if (reverse = stop < start) n = start, start = stop, stop = n; - if ((step = tickIncrement(start, stop, count)) === 0 || !isFinite(step)) return []; - - if (step > 0) { - let r0 = Math.round(start / step), r1 = Math.round(stop / step); - if (r0 * step < start) ++r0; - if (r1 * step > stop) --r1; - ticks = new Array(n = r1 - r0 + 1); - while (++i < n) ticks[i] = (r0 + i) * step; + function tickSpec(start, stop, count) { + const step = (stop - start) / Math.max(0, count), + power = Math.floor(Math.log10(step)), + error = step / Math.pow(10, power), + factor = error >= e10 ? 10 : error >= e5 ? 5 : error >= e2 ? 2 : 1; + let i1, i2, inc; + if (power < 0) { + inc = Math.pow(10, -power) / factor; + i1 = Math.round(start * inc); + i2 = Math.round(stop * inc); + if (i1 / inc < start) ++i1; + if (i2 / inc > stop) --i2; + inc = -inc; } else { - step = -step; - let r0 = Math.round(start * step), r1 = Math.round(stop * step); - if (r0 / step < start) ++r0; - if (r1 / step > stop) --r1; - ticks = new Array(n = r1 - r0 + 1); - while (++i < n) ticks[i] = (r0 + i) / step; + inc = Math.pow(10, power) * factor; + i1 = Math.round(start / inc); + i2 = Math.round(stop / inc); + if (i1 * inc < start) ++i1; + if (i2 * inc > stop) --i2; } + if (i2 < i1 && 0.5 <= count && count < 2) return tickSpec(start, stop, count * 2); + return [i1, i2, inc]; + } - if (reverse) ticks.reverse(); - + function ticks(start, stop, count) { + stop = +stop, start = +start, count = +count; + if (!(count > 0)) return []; + if (start === stop) return [start]; + const reverse = stop < start, [i1, i2, inc] = reverse ? tickSpec(stop, start, count) : tickSpec(start, stop, count); + if (!(i2 >= i1)) return []; + const n = i2 - i1 + 1, ticks = new Array(n); + if (reverse) { + if (inc < 0) for (let i = 0; i < n; ++i) ticks[i] = (i2 - i) / -inc; + else for (let i = 0; i < n; ++i) ticks[i] = (i2 - i) * inc; + } else { + if (inc < 0) for (let i = 0; i < n; ++i) ticks[i] = (i1 + i) / -inc; + else for (let i = 0; i < n; ++i) ticks[i] = (i1 + i) * inc; + } return ticks; } function tickIncrement(start, stop, count) { - var step = (stop - start) / Math.max(0, count), - power = Math.floor(Math.log(step) / Math.LN10), - error = step / Math.pow(10, power); - return power >= 0 - ? (error >= e10 ? 10 : error >= e5 ? 5 : error >= e2 ? 2 : 1) * Math.pow(10, power) - : -Math.pow(10, -power) / (error >= e10 ? 10 : error >= e5 ? 5 : error >= e2 ? 2 : 1); + stop = +stop, start = +start, count = +count; + return tickSpec(start, stop, count)[2]; } function tickStep(start, stop, count) { - var step0 = Math.abs(stop - start) / Math.max(0, count), - step1 = Math.pow(10, Math.floor(Math.log(step0) / Math.LN10)), - error = step0 / step1; - if (error >= e10) step1 *= 10; - else if (error >= e5) step1 *= 5; - else if (error >= e2) step1 *= 2; - return stop < start ? -step1 : step1; + stop = +stop, start = +start, count = +count; + const reverse = stop < start, inc = reverse ? tickIncrement(stop, start, count) : tickIncrement(start, stop, count); + return (reverse ? -1 : 1) * (inc < 0 ? 1 / -inc : inc); } const TOP_PADDING = 20; @@ -3075,7 +3066,7 @@ var drawChart = (function (exports) { else { textProps.y = height / 2; } - s(() => { + y(() => { if (width == 0 || height == 0 || !textRef.current) { return; } @@ -3100,18 +3091,18 @@ var drawChart = (function (exports) { if (width == 0 || height == 0) { return null; } - return (o$1("g", Object.assign({ className: "node", transform: `translate(${x0},${y0})`, onClick: (event) => { + return (o$1("g", { className: "node", transform: `translate(${x0},${y0})`, onClick: (event) => { event.stopPropagation(); onClick(node); }, onMouseOver: (event) => { event.stopPropagation(); onMouseOver(node); - } }, { children: [o$1("rect", { fill: backgroundColor, rx: 2, ry: 2, width: x1 - x0, height: y1 - y0, stroke: selected ? "#fff" : undefined, "stroke-width": selected ? 2 : undefined }), o$1("text", Object.assign({ ref: textRef, fill: fontColor, onClick: (event) => { + }, children: [o$1("rect", { fill: backgroundColor, rx: 2, ry: 2, width: x1 - x0, height: y1 - y0, stroke: selected ? "#fff" : undefined, "stroke-width": selected ? 2 : undefined }), o$1("text", Object.assign({ ref: textRef, fill: fontColor, onClick: (event) => { var _a; if (((_a = window.getSelection()) === null || _a === void 0 ? void 0 : _a.toString()) !== "") { event.stopPropagation(); } - } }, textProps, { children: data.name }))] }))); + } }, textProps, { children: data.name }))] })); }; const TreeMap = ({ root, onNodeHover, selectedNode, onNodeClick, }) => { @@ -3128,18 +3119,14 @@ var drawChart = (function (exports) { return nestedData; }, [root]); console.timeEnd("layering"); - return (o$1("svg", Object.assign({ xmlns: "http://www.w3.org/2000/svg", viewBox: `0 0 ${width} ${height}` }, { children: nestedData.map(({ key, values }) => { - return (o$1("g", Object.assign({ className: "layer" }, { children: values.map((node) => { + return (o$1("svg", { xmlns: "http://www.w3.org/2000/svg", viewBox: `0 0 ${width} ${height}`, children: nestedData.map(({ key, values }) => { + return (o$1("g", { className: "layer", children: values.map((node) => { return (o$1(Node, { node: node, onMouseOver: onNodeHover, selected: selectedNode === node, onClick: onNodeClick }, getModuleIds(node.data).nodeUid.id)); - }) }), key)); - }) }))); + }) }, key)); + }) })); }; - var bytesExports = {}; - var bytes$1 = { - get exports(){ return bytesExports; }, - set exports(v){ bytesExports = v; }, - }; + var bytes$1 = {exports: {}}; /*! * bytes @@ -3154,8 +3141,8 @@ var drawChart = (function (exports) { */ bytes$1.exports = bytes; - var format_1 = bytesExports.format = format$1; - bytesExports.parse = parse; + var format_1 = bytes$1.exports.format = format$1; + bytes$1.exports.parse = parse; /** * Module variables. @@ -3318,7 +3305,7 @@ var drawChart = (function (exports) { const Tooltip = ({ node, visible, root, sizeProperty, }) => { const { availableSizeProperties, getModuleSize, data } = q(StaticContext); const ref = _(null); - const [style, setStyle] = p({}); + const [style, setStyle] = h({}); const content = F(() => { if (!node) return null; @@ -3336,7 +3323,7 @@ var drawChart = (function (exports) { const mainUid = data.nodeParts[node.data.uid].metaUid; dataNode = data.nodeMetas[mainUid]; } - return (o$1(p$1, { children: [o$1("div", { children: path }), availableSizeProperties.map((sizeProp) => { + return (o$1(k$1, { children: [o$1("div", { children: path }), availableSizeProperties.map((sizeProp) => { if (sizeProp === sizeProperty) { return (o$1("div", { children: [o$1("b", { children: [LABELS[sizeProp], ": ", format_1(mainSize)] }), " ", "(", percentageString, ")"] }, sizeProp)); } @@ -3346,7 +3333,7 @@ var drawChart = (function (exports) { }), o$1("br", {}), dataNode && dataNode.importedBy.length > 0 && (o$1("div", { children: [o$1("div", { children: [o$1("b", { children: "Imported By" }), ":"] }), dataNode.importedBy.map(({ uid }) => { const id = data.nodeMetas[uid].id; return o$1("div", { children: id }, id); - })] })), o$1("br", {}), o$1("small", { children: data.options.sourcemap ? SOURCEMAP_RENDERED : RENDRED }), (data.options.gzip || data.options.brotli) && (o$1(p$1, { children: [o$1("br", {}), o$1("small", { children: COMPRESSED })] }))] })); + })] })), o$1("br", {}), o$1("small", { children: data.options.sourcemap ? SOURCEMAP_RENDERED : RENDRED }), (data.options.gzip || data.options.brotli) && (o$1(k$1, { children: [o$1("br", {}), o$1("small", { children: COMPRESSED })] }))] })); }, [availableSizeProperties, data, getModuleSize, node, root.data, sizeProperty]); const updatePosition = (mouseCoords) => { if (!ref.current) @@ -3366,7 +3353,7 @@ var drawChart = (function (exports) { } setStyle(pos); }; - h(() => { + p(() => { const handleMouseMove = (event) => { updatePosition({ x: event.pageX, @@ -3378,13 +3365,13 @@ var drawChart = (function (exports) { document.removeEventListener("mousemove", handleMouseMove, true); }; }, []); - return (o$1("div", Object.assign({ className: `tooltip ${visible ? "" : "tooltip-hidden"}`, ref: ref, style: style }, { children: content }))); + return (o$1("div", { className: `tooltip ${visible ? "" : "tooltip-hidden"}`, ref: ref, style: style, children: content })); }; const Chart = ({ root, sizeProperty, selectedNode, setSelectedNode, }) => { - const [showTooltip, setShowTooltip] = p(false); - const [tooltipNode, setTooltipNode] = p(undefined); - h(() => { + const [showTooltip, setShowTooltip] = h(false); + const [tooltipNode, setTooltipNode] = h(undefined); + p(() => { const handleMouseOut = () => { setShowTooltip(false); }; @@ -3393,7 +3380,7 @@ var drawChart = (function (exports) { document.removeEventListener("mouseover", handleMouseOut); }; }, []); - return (o$1(p$1, { children: [o$1(TreeMap, { root: root, onNodeHover: (node) => { + return (o$1(k$1, { children: [o$1(TreeMap, { root: root, onNodeHover: (node) => { setTooltipNode(node); setShowTooltip(true); }, selectedNode: selectedNode, onNodeClick: (node) => { @@ -3403,8 +3390,8 @@ var drawChart = (function (exports) { const Main = () => { const { availableSizeProperties, rawHierarchy, getModuleSize, layout, data } = q(StaticContext); - const [sizeProperty, setSizeProperty] = p(availableSizeProperties[0]); - const [selectedNode, setSelectedNode] = p(undefined); + const [sizeProperty, setSizeProperty] = h(availableSizeProperties[0]); + const [selectedNode, setSelectedNode] = h(undefined); const { getModuleFilterMultiplier, setExcludeFilter, setIncludeFilter } = useFilter(); console.time("getNodeSizeMultiplier"); const getNodeSizeMultiplier = F(() => { @@ -3459,7 +3446,7 @@ var drawChart = (function (exports) { sizeProperty, ]); console.timeEnd("root hierarchy compute"); - return (o$1(p$1, { children: [o$1(SideBar, { sizeProperty: sizeProperty, availableSizeProperties: availableSizeProperties, setSizeProperty: setSizeProperty, onExcludeChange: setExcludeFilter, onIncludeChange: setIncludeFilter }), o$1(Chart, { root: root, sizeProperty: sizeProperty, selectedNode: selectedNode, setSelectedNode: setSelectedNode })] })); + return (o$1(k$1, { children: [o$1(SideBar, { sizeProperty: sizeProperty, availableSizeProperties: availableSizeProperties, setSizeProperty: setSizeProperty, onExcludeChange: setExcludeFilter, onIncludeChange: setIncludeFilter }), o$1(Chart, { root: root, sizeProperty: sizeProperty, selectedNode: selectedNode, setSelectedNode: setSelectedNode })] })); }; function initRange(domain, range) { @@ -3895,179 +3882,6 @@ var drawChart = (function (exports) { : m1) * 255; } - const radians = Math.PI / 180; - const degrees = 180 / Math.PI; - - // https://observablehq.com/@mbostock/lab-and-rgb - const K = 18, - Xn = 0.96422, - Yn = 1, - Zn = 0.82521, - t0$1 = 4 / 29, - t1$1 = 6 / 29, - t2 = 3 * t1$1 * t1$1, - t3 = t1$1 * t1$1 * t1$1; - - function labConvert(o) { - if (o instanceof Lab) return new Lab(o.l, o.a, o.b, o.opacity); - if (o instanceof Hcl) return hcl2lab(o); - if (!(o instanceof Rgb)) o = rgbConvert(o); - var r = rgb2lrgb(o.r), - g = rgb2lrgb(o.g), - b = rgb2lrgb(o.b), - y = xyz2lab((0.2225045 * r + 0.7168786 * g + 0.0606169 * b) / Yn), x, z; - if (r === g && g === b) x = z = y; else { - x = xyz2lab((0.4360747 * r + 0.3850649 * g + 0.1430804 * b) / Xn); - z = xyz2lab((0.0139322 * r + 0.0971045 * g + 0.7141733 * b) / Zn); - } - return new Lab(116 * y - 16, 500 * (x - y), 200 * (y - z), o.opacity); - } - - function lab(l, a, b, opacity) { - return arguments.length === 1 ? labConvert(l) : new Lab(l, a, b, opacity == null ? 1 : opacity); - } - - function Lab(l, a, b, opacity) { - this.l = +l; - this.a = +a; - this.b = +b; - this.opacity = +opacity; - } - - define(Lab, lab, extend(Color, { - brighter(k) { - return new Lab(this.l + K * (k == null ? 1 : k), this.a, this.b, this.opacity); - }, - darker(k) { - return new Lab(this.l - K * (k == null ? 1 : k), this.a, this.b, this.opacity); - }, - rgb() { - var y = (this.l + 16) / 116, - x = isNaN(this.a) ? y : y + this.a / 500, - z = isNaN(this.b) ? y : y - this.b / 200; - x = Xn * lab2xyz(x); - y = Yn * lab2xyz(y); - z = Zn * lab2xyz(z); - return new Rgb( - lrgb2rgb( 3.1338561 * x - 1.6168667 * y - 0.4906146 * z), - lrgb2rgb(-0.9787684 * x + 1.9161415 * y + 0.0334540 * z), - lrgb2rgb( 0.0719453 * x - 0.2289914 * y + 1.4052427 * z), - this.opacity - ); - } - })); - - function xyz2lab(t) { - return t > t3 ? Math.pow(t, 1 / 3) : t / t2 + t0$1; - } - - function lab2xyz(t) { - return t > t1$1 ? t * t * t : t2 * (t - t0$1); - } - - function lrgb2rgb(x) { - return 255 * (x <= 0.0031308 ? 12.92 * x : 1.055 * Math.pow(x, 1 / 2.4) - 0.055); - } - - function rgb2lrgb(x) { - return (x /= 255) <= 0.04045 ? x / 12.92 : Math.pow((x + 0.055) / 1.055, 2.4); - } - - function hclConvert(o) { - if (o instanceof Hcl) return new Hcl(o.h, o.c, o.l, o.opacity); - if (!(o instanceof Lab)) o = labConvert(o); - if (o.a === 0 && o.b === 0) return new Hcl(NaN, 0 < o.l && o.l < 100 ? 0 : NaN, o.l, o.opacity); - var h = Math.atan2(o.b, o.a) * degrees; - return new Hcl(h < 0 ? h + 360 : h, Math.sqrt(o.a * o.a + o.b * o.b), o.l, o.opacity); - } - - function hcl(h, c, l, opacity) { - return arguments.length === 1 ? hclConvert(h) : new Hcl(h, c, l, opacity == null ? 1 : opacity); - } - - function Hcl(h, c, l, opacity) { - this.h = +h; - this.c = +c; - this.l = +l; - this.opacity = +opacity; - } - - function hcl2lab(o) { - if (isNaN(o.h)) return new Lab(o.l, 0, 0, o.opacity); - var h = o.h * radians; - return new Lab(o.l, Math.cos(h) * o.c, Math.sin(h) * o.c, o.opacity); - } - - define(Hcl, hcl, extend(Color, { - brighter(k) { - return new Hcl(this.h, this.c, this.l + K * (k == null ? 1 : k), this.opacity); - }, - darker(k) { - return new Hcl(this.h, this.c, this.l - K * (k == null ? 1 : k), this.opacity); - }, - rgb() { - return hcl2lab(this).rgb(); - } - })); - - var A = -0.14861, - B = +1.78277, - C = -0.29227, - D = -0.90649, - E = +1.97294, - ED = E * D, - EB = E * B, - BC_DA = B * C - D * A; - - function cubehelixConvert(o) { - if (o instanceof Cubehelix) return new Cubehelix(o.h, o.s, o.l, o.opacity); - if (!(o instanceof Rgb)) o = rgbConvert(o); - var r = o.r / 255, - g = o.g / 255, - b = o.b / 255, - l = (BC_DA * b + ED * r - EB * g) / (BC_DA + ED - EB), - bl = b - l, - k = (E * (g - l) - C * bl) / D, - s = Math.sqrt(k * k + bl * bl) / (E * l * (1 - l)), // NaN if l=0 or l=1 - h = s ? Math.atan2(k, bl) * degrees - 120 : NaN; - return new Cubehelix(h < 0 ? h + 360 : h, s, l, o.opacity); - } - - function cubehelix$1(h, s, l, opacity) { - return arguments.length === 1 ? cubehelixConvert(h) : new Cubehelix(h, s, l, opacity == null ? 1 : opacity); - } - - function Cubehelix(h, s, l, opacity) { - this.h = +h; - this.s = +s; - this.l = +l; - this.opacity = +opacity; - } - - define(Cubehelix, cubehelix$1, extend(Color, { - brighter(k) { - k = k == null ? brighter : Math.pow(brighter, k); - return new Cubehelix(this.h, this.s, this.l * k, this.opacity); - }, - darker(k) { - k = k == null ? darker : Math.pow(darker, k); - return new Cubehelix(this.h, this.s, this.l * k, this.opacity); - }, - rgb() { - var h = isNaN(this.h) ? 0 : (this.h + 120) * radians, - l = +this.l, - a = isNaN(this.s) ? 0 : this.s * l * (1 - l), - cosh = Math.cos(h), - sinh = Math.sin(h); - return new Rgb( - 255 * (l + a * (A * cosh + B * sinh)), - 255 * (l + a * (C * cosh + D * sinh)), - 255 * (l + a * (E * cosh)), - this.opacity - ); - } - })); - var constant = x => () => x; function linear$1(a, d) { @@ -4082,11 +3896,6 @@ var drawChart = (function (exports) { }; } - function hue(a, b) { - var d = b - a; - return d ? linear$1(a, d > 180 || d < -180 ? d - 360 * Math.round(d / 360) : d) : constant(isNaN(a) ? b : a); - } - function gamma(y) { return (y = +y) === 1 ? nogamma : function(a, b) { return b - a ? exponential(a, b, y) : constant(isNaN(a) ? b : a); @@ -4268,105 +4077,6 @@ var drawChart = (function (exports) { }; } - var epsilon2 = 1e-12; - - function cosh(x) { - return ((x = Math.exp(x)) + 1 / x) / 2; - } - - function sinh(x) { - return ((x = Math.exp(x)) - 1 / x) / 2; - } - - function tanh(x) { - return ((x = Math.exp(2 * x)) - 1) / (x + 1); - } - - ((function zoomRho(rho, rho2, rho4) { - - // p0 = [ux0, uy0, w0] - // p1 = [ux1, uy1, w1] - function zoom(p0, p1) { - var ux0 = p0[0], uy0 = p0[1], w0 = p0[2], - ux1 = p1[0], uy1 = p1[1], w1 = p1[2], - dx = ux1 - ux0, - dy = uy1 - uy0, - d2 = dx * dx + dy * dy, - i, - S; - - // Special case for u0 ≅ u1. - if (d2 < epsilon2) { - S = Math.log(w1 / w0) / rho; - i = function(t) { - return [ - ux0 + t * dx, - uy0 + t * dy, - w0 * Math.exp(rho * t * S) - ]; - }; - } - - // General case. - else { - var d1 = Math.sqrt(d2), - b0 = (w1 * w1 - w0 * w0 + rho4 * d2) / (2 * w0 * rho2 * d1), - b1 = (w1 * w1 - w0 * w0 - rho4 * d2) / (2 * w1 * rho2 * d1), - r0 = Math.log(Math.sqrt(b0 * b0 + 1) - b0), - r1 = Math.log(Math.sqrt(b1 * b1 + 1) - b1); - S = (r1 - r0) / rho; - i = function(t) { - var s = t * S, - coshr0 = cosh(r0), - u = w0 / (rho2 * d1) * (coshr0 * tanh(rho * s + r0) - sinh(r0)); - return [ - ux0 + u * dx, - uy0 + u * dy, - w0 * coshr0 / cosh(rho * s + r0) - ]; - }; - } - - i.duration = S * 1000 * rho / Math.SQRT2; - - return i; - } - - zoom.rho = function(_) { - var _1 = Math.max(1e-3, +_), _2 = _1 * _1, _4 = _2 * _2; - return zoomRho(_1, _2, _4); - }; - - return zoom; - }))(Math.SQRT2, 2, 4); - - function cubehelix(hue) { - return (function cubehelixGamma(y) { - y = +y; - - function cubehelix(start, end) { - var h = hue((start = cubehelix$1(start)).h, (end = cubehelix$1(end)).h), - s = nogamma(start.s, end.s), - l = nogamma(start.l, end.l), - opacity = nogamma(start.opacity, end.opacity); - return function(t) { - start.h = h(t); - start.s = s(t); - start.l = l(Math.pow(t, y)); - start.opacity = opacity(t); - return start + ""; - }; - } - - cubehelix.gamma = cubehelixGamma; - - return cubehelix; - })(1); - } - - cubehelix(hue); - cubehelix(nogamma); - function constants(x) { return function() { return x; @@ -4422,7 +4132,7 @@ var drawChart = (function (exports) { } return function(x) { - var i = bisectRight(domain, x, 1, j) - 1; + var i = bisect(domain, x, 1, j) - 1; return r[i](d[i](x)); }; } @@ -4658,7 +4368,7 @@ var drawChart = (function (exports) { var map = Array.prototype.map, prefixes = ["y","z","a","f","p","n","µ","m","","k","M","G","T","P","E","Z","Y"]; - function formatLocale$1(locale) { + function formatLocale(locale) { var group = locale.grouping === undefined || locale.thousands === undefined ? identity : formatGroup(map.call(locale.grouping, Number), locale.thousands + ""), currencyPrefix = locale.currency === undefined ? "" : locale.currency[0] + "", currencySuffix = locale.currency === undefined ? "" : locale.currency[1] + "", @@ -4795,21 +4505,21 @@ var drawChart = (function (exports) { }; } - var locale$1; + var locale; var format; var formatPrefix; - defaultLocale$1({ + defaultLocale({ thousands: ",", grouping: [3], currency: ["$", ""] }); - function defaultLocale$1(definition) { - locale$1 = formatLocale$1(definition); - format = locale$1.format; - formatPrefix = locale$1.formatPrefix; - return locale$1; + function defaultLocale(definition) { + locale = formatLocale(definition); + format = locale.format; + formatPrefix = locale.formatPrefix; + return locale; } function precisionFixed(step) { @@ -4918,1055 +4628,6 @@ var drawChart = (function (exports) { return linearish(scale); } - const t0 = new Date, t1 = new Date; - - function timeInterval(floori, offseti, count, field) { - - function interval(date) { - return floori(date = arguments.length === 0 ? new Date : new Date(+date)), date; - } - - interval.floor = (date) => { - return floori(date = new Date(+date)), date; - }; - - interval.ceil = (date) => { - return floori(date = new Date(date - 1)), offseti(date, 1), floori(date), date; - }; - - interval.round = (date) => { - const d0 = interval(date), d1 = interval.ceil(date); - return date - d0 < d1 - date ? d0 : d1; - }; - - interval.offset = (date, step) => { - return offseti(date = new Date(+date), step == null ? 1 : Math.floor(step)), date; - }; - - interval.range = (start, stop, step) => { - const range = []; - start = interval.ceil(start); - step = step == null ? 1 : Math.floor(step); - if (!(start < stop) || !(step > 0)) return range; // also handles Invalid Date - let previous; - do range.push(previous = new Date(+start)), offseti(start, step), floori(start); - while (previous < start && start < stop); - return range; - }; - - interval.filter = (test) => { - return timeInterval((date) => { - if (date >= date) while (floori(date), !test(date)) date.setTime(date - 1); - }, (date, step) => { - if (date >= date) { - if (step < 0) while (++step <= 0) { - while (offseti(date, -1), !test(date)) {} // eslint-disable-line no-empty - } else while (--step >= 0) { - while (offseti(date, +1), !test(date)) {} // eslint-disable-line no-empty - } - } - }); - }; - - if (count) { - interval.count = (start, end) => { - t0.setTime(+start), t1.setTime(+end); - floori(t0), floori(t1); - return Math.floor(count(t0, t1)); - }; - - interval.every = (step) => { - step = Math.floor(step); - return !isFinite(step) || !(step > 0) ? null - : !(step > 1) ? interval - : interval.filter(field - ? (d) => field(d) % step === 0 - : (d) => interval.count(0, d) % step === 0); - }; - } - - return interval; - } - - const millisecond = timeInterval(() => { - // noop - }, (date, step) => { - date.setTime(+date + step); - }, (start, end) => { - return end - start; - }); - - // An optimized implementation for this simple case. - millisecond.every = (k) => { - k = Math.floor(k); - if (!isFinite(k) || !(k > 0)) return null; - if (!(k > 1)) return millisecond; - return timeInterval((date) => { - date.setTime(Math.floor(date / k) * k); - }, (date, step) => { - date.setTime(+date + step * k); - }, (start, end) => { - return (end - start) / k; - }); - }; - - millisecond.range; - - const durationSecond = 1000; - const durationMinute = durationSecond * 60; - const durationHour = durationMinute * 60; - const durationDay = durationHour * 24; - const durationWeek = durationDay * 7; - - const second = timeInterval((date) => { - date.setTime(date - date.getMilliseconds()); - }, (date, step) => { - date.setTime(+date + step * durationSecond); - }, (start, end) => { - return (end - start) / durationSecond; - }, (date) => { - return date.getUTCSeconds(); - }); - - second.range; - - const timeMinute = timeInterval((date) => { - date.setTime(date - date.getMilliseconds() - date.getSeconds() * durationSecond); - }, (date, step) => { - date.setTime(+date + step * durationMinute); - }, (start, end) => { - return (end - start) / durationMinute; - }, (date) => { - return date.getMinutes(); - }); - - timeMinute.range; - - const utcMinute = timeInterval((date) => { - date.setUTCSeconds(0, 0); - }, (date, step) => { - date.setTime(+date + step * durationMinute); - }, (start, end) => { - return (end - start) / durationMinute; - }, (date) => { - return date.getUTCMinutes(); - }); - - utcMinute.range; - - const timeHour = timeInterval((date) => { - date.setTime(date - date.getMilliseconds() - date.getSeconds() * durationSecond - date.getMinutes() * durationMinute); - }, (date, step) => { - date.setTime(+date + step * durationHour); - }, (start, end) => { - return (end - start) / durationHour; - }, (date) => { - return date.getHours(); - }); - - timeHour.range; - - const utcHour = timeInterval((date) => { - date.setUTCMinutes(0, 0, 0); - }, (date, step) => { - date.setTime(+date + step * durationHour); - }, (start, end) => { - return (end - start) / durationHour; - }, (date) => { - return date.getUTCHours(); - }); - - utcHour.range; - - const timeDay = timeInterval( - date => date.setHours(0, 0, 0, 0), - (date, step) => date.setDate(date.getDate() + step), - (start, end) => (end - start - (end.getTimezoneOffset() - start.getTimezoneOffset()) * durationMinute) / durationDay, - date => date.getDate() - 1 - ); - - timeDay.range; - - const utcDay = timeInterval((date) => { - date.setUTCHours(0, 0, 0, 0); - }, (date, step) => { - date.setUTCDate(date.getUTCDate() + step); - }, (start, end) => { - return (end - start) / durationDay; - }, (date) => { - return date.getUTCDate() - 1; - }); - - utcDay.range; - - const unixDay = timeInterval((date) => { - date.setUTCHours(0, 0, 0, 0); - }, (date, step) => { - date.setUTCDate(date.getUTCDate() + step); - }, (start, end) => { - return (end - start) / durationDay; - }, (date) => { - return Math.floor(date / durationDay); - }); - - unixDay.range; - - function timeWeekday(i) { - return timeInterval((date) => { - date.setDate(date.getDate() - (date.getDay() + 7 - i) % 7); - date.setHours(0, 0, 0, 0); - }, (date, step) => { - date.setDate(date.getDate() + step * 7); - }, (start, end) => { - return (end - start - (end.getTimezoneOffset() - start.getTimezoneOffset()) * durationMinute) / durationWeek; - }); - } - - const timeSunday = timeWeekday(0); - const timeMonday = timeWeekday(1); - const timeTuesday = timeWeekday(2); - const timeWednesday = timeWeekday(3); - const timeThursday = timeWeekday(4); - const timeFriday = timeWeekday(5); - const timeSaturday = timeWeekday(6); - - timeSunday.range; - timeMonday.range; - timeTuesday.range; - timeWednesday.range; - timeThursday.range; - timeFriday.range; - timeSaturday.range; - - function utcWeekday(i) { - return timeInterval((date) => { - date.setUTCDate(date.getUTCDate() - (date.getUTCDay() + 7 - i) % 7); - date.setUTCHours(0, 0, 0, 0); - }, (date, step) => { - date.setUTCDate(date.getUTCDate() + step * 7); - }, (start, end) => { - return (end - start) / durationWeek; - }); - } - - const utcSunday = utcWeekday(0); - const utcMonday = utcWeekday(1); - const utcTuesday = utcWeekday(2); - const utcWednesday = utcWeekday(3); - const utcThursday = utcWeekday(4); - const utcFriday = utcWeekday(5); - const utcSaturday = utcWeekday(6); - - utcSunday.range; - utcMonday.range; - utcTuesday.range; - utcWednesday.range; - utcThursday.range; - utcFriday.range; - utcSaturday.range; - - const timeMonth = timeInterval((date) => { - date.setDate(1); - date.setHours(0, 0, 0, 0); - }, (date, step) => { - date.setMonth(date.getMonth() + step); - }, (start, end) => { - return end.getMonth() - start.getMonth() + (end.getFullYear() - start.getFullYear()) * 12; - }, (date) => { - return date.getMonth(); - }); - - timeMonth.range; - - const utcMonth = timeInterval((date) => { - date.setUTCDate(1); - date.setUTCHours(0, 0, 0, 0); - }, (date, step) => { - date.setUTCMonth(date.getUTCMonth() + step); - }, (start, end) => { - return end.getUTCMonth() - start.getUTCMonth() + (end.getUTCFullYear() - start.getUTCFullYear()) * 12; - }, (date) => { - return date.getUTCMonth(); - }); - - utcMonth.range; - - const timeYear = timeInterval((date) => { - date.setMonth(0, 1); - date.setHours(0, 0, 0, 0); - }, (date, step) => { - date.setFullYear(date.getFullYear() + step); - }, (start, end) => { - return end.getFullYear() - start.getFullYear(); - }, (date) => { - return date.getFullYear(); - }); - - // An optimized implementation for this simple case. - timeYear.every = (k) => { - return !isFinite(k = Math.floor(k)) || !(k > 0) ? null : timeInterval((date) => { - date.setFullYear(Math.floor(date.getFullYear() / k) * k); - date.setMonth(0, 1); - date.setHours(0, 0, 0, 0); - }, (date, step) => { - date.setFullYear(date.getFullYear() + step * k); - }); - }; - - timeYear.range; - - const utcYear = timeInterval((date) => { - date.setUTCMonth(0, 1); - date.setUTCHours(0, 0, 0, 0); - }, (date, step) => { - date.setUTCFullYear(date.getUTCFullYear() + step); - }, (start, end) => { - return end.getUTCFullYear() - start.getUTCFullYear(); - }, (date) => { - return date.getUTCFullYear(); - }); - - // An optimized implementation for this simple case. - utcYear.every = (k) => { - return !isFinite(k = Math.floor(k)) || !(k > 0) ? null : timeInterval((date) => { - date.setUTCFullYear(Math.floor(date.getUTCFullYear() / k) * k); - date.setUTCMonth(0, 1); - date.setUTCHours(0, 0, 0, 0); - }, (date, step) => { - date.setUTCFullYear(date.getUTCFullYear() + step * k); - }); - }; - - utcYear.range; - - function localDate(d) { - if (0 <= d.y && d.y < 100) { - var date = new Date(-1, d.m, d.d, d.H, d.M, d.S, d.L); - date.setFullYear(d.y); - return date; - } - return new Date(d.y, d.m, d.d, d.H, d.M, d.S, d.L); - } - - function utcDate(d) { - if (0 <= d.y && d.y < 100) { - var date = new Date(Date.UTC(-1, d.m, d.d, d.H, d.M, d.S, d.L)); - date.setUTCFullYear(d.y); - return date; - } - return new Date(Date.UTC(d.y, d.m, d.d, d.H, d.M, d.S, d.L)); - } - - function newDate(y, m, d) { - return {y: y, m: m, d: d, H: 0, M: 0, S: 0, L: 0}; - } - - function formatLocale(locale) { - var locale_dateTime = locale.dateTime, - locale_date = locale.date, - locale_time = locale.time, - locale_periods = locale.periods, - locale_weekdays = locale.days, - locale_shortWeekdays = locale.shortDays, - locale_months = locale.months, - locale_shortMonths = locale.shortMonths; - - var periodRe = formatRe(locale_periods), - periodLookup = formatLookup(locale_periods), - weekdayRe = formatRe(locale_weekdays), - weekdayLookup = formatLookup(locale_weekdays), - shortWeekdayRe = formatRe(locale_shortWeekdays), - shortWeekdayLookup = formatLookup(locale_shortWeekdays), - monthRe = formatRe(locale_months), - monthLookup = formatLookup(locale_months), - shortMonthRe = formatRe(locale_shortMonths), - shortMonthLookup = formatLookup(locale_shortMonths); - - var formats = { - "a": formatShortWeekday, - "A": formatWeekday, - "b": formatShortMonth, - "B": formatMonth, - "c": null, - "d": formatDayOfMonth, - "e": formatDayOfMonth, - "f": formatMicroseconds, - "g": formatYearISO, - "G": formatFullYearISO, - "H": formatHour24, - "I": formatHour12, - "j": formatDayOfYear, - "L": formatMilliseconds, - "m": formatMonthNumber, - "M": formatMinutes, - "p": formatPeriod, - "q": formatQuarter, - "Q": formatUnixTimestamp, - "s": formatUnixTimestampSeconds, - "S": formatSeconds, - "u": formatWeekdayNumberMonday, - "U": formatWeekNumberSunday, - "V": formatWeekNumberISO, - "w": formatWeekdayNumberSunday, - "W": formatWeekNumberMonday, - "x": null, - "X": null, - "y": formatYear, - "Y": formatFullYear, - "Z": formatZone, - "%": formatLiteralPercent - }; - - var utcFormats = { - "a": formatUTCShortWeekday, - "A": formatUTCWeekday, - "b": formatUTCShortMonth, - "B": formatUTCMonth, - "c": null, - "d": formatUTCDayOfMonth, - "e": formatUTCDayOfMonth, - "f": formatUTCMicroseconds, - "g": formatUTCYearISO, - "G": formatUTCFullYearISO, - "H": formatUTCHour24, - "I": formatUTCHour12, - "j": formatUTCDayOfYear, - "L": formatUTCMilliseconds, - "m": formatUTCMonthNumber, - "M": formatUTCMinutes, - "p": formatUTCPeriod, - "q": formatUTCQuarter, - "Q": formatUnixTimestamp, - "s": formatUnixTimestampSeconds, - "S": formatUTCSeconds, - "u": formatUTCWeekdayNumberMonday, - "U": formatUTCWeekNumberSunday, - "V": formatUTCWeekNumberISO, - "w": formatUTCWeekdayNumberSunday, - "W": formatUTCWeekNumberMonday, - "x": null, - "X": null, - "y": formatUTCYear, - "Y": formatUTCFullYear, - "Z": formatUTCZone, - "%": formatLiteralPercent - }; - - var parses = { - "a": parseShortWeekday, - "A": parseWeekday, - "b": parseShortMonth, - "B": parseMonth, - "c": parseLocaleDateTime, - "d": parseDayOfMonth, - "e": parseDayOfMonth, - "f": parseMicroseconds, - "g": parseYear, - "G": parseFullYear, - "H": parseHour24, - "I": parseHour24, - "j": parseDayOfYear, - "L": parseMilliseconds, - "m": parseMonthNumber, - "M": parseMinutes, - "p": parsePeriod, - "q": parseQuarter, - "Q": parseUnixTimestamp, - "s": parseUnixTimestampSeconds, - "S": parseSeconds, - "u": parseWeekdayNumberMonday, - "U": parseWeekNumberSunday, - "V": parseWeekNumberISO, - "w": parseWeekdayNumberSunday, - "W": parseWeekNumberMonday, - "x": parseLocaleDate, - "X": parseLocaleTime, - "y": parseYear, - "Y": parseFullYear, - "Z": parseZone, - "%": parseLiteralPercent - }; - - // These recursive directive definitions must be deferred. - formats.x = newFormat(locale_date, formats); - formats.X = newFormat(locale_time, formats); - formats.c = newFormat(locale_dateTime, formats); - utcFormats.x = newFormat(locale_date, utcFormats); - utcFormats.X = newFormat(locale_time, utcFormats); - utcFormats.c = newFormat(locale_dateTime, utcFormats); - - function newFormat(specifier, formats) { - return function(date) { - var string = [], - i = -1, - j = 0, - n = specifier.length, - c, - pad, - format; - - if (!(date instanceof Date)) date = new Date(+date); - - while (++i < n) { - if (specifier.charCodeAt(i) === 37) { - string.push(specifier.slice(j, i)); - if ((pad = pads[c = specifier.charAt(++i)]) != null) c = specifier.charAt(++i); - else pad = c === "e" ? " " : "0"; - if (format = formats[c]) c = format(date, pad); - string.push(c); - j = i + 1; - } - } - - string.push(specifier.slice(j, i)); - return string.join(""); - }; - } - - function newParse(specifier, Z) { - return function(string) { - var d = newDate(1900, undefined, 1), - i = parseSpecifier(d, specifier, string += "", 0), - week, day; - if (i != string.length) return null; - - // If a UNIX timestamp is specified, return it. - if ("Q" in d) return new Date(d.Q); - if ("s" in d) return new Date(d.s * 1000 + ("L" in d ? d.L : 0)); - - // If this is utcParse, never use the local timezone. - if (Z && !("Z" in d)) d.Z = 0; - - // The am-pm flag is 0 for AM, and 1 for PM. - if ("p" in d) d.H = d.H % 12 + d.p * 12; - - // If the month was not specified, inherit from the quarter. - if (d.m === undefined) d.m = "q" in d ? d.q : 0; - - // Convert day-of-week and week-of-year to day-of-year. - if ("V" in d) { - if (d.V < 1 || d.V > 53) return null; - if (!("w" in d)) d.w = 1; - if ("Z" in d) { - week = utcDate(newDate(d.y, 0, 1)), day = week.getUTCDay(); - week = day > 4 || day === 0 ? utcMonday.ceil(week) : utcMonday(week); - week = utcDay.offset(week, (d.V - 1) * 7); - d.y = week.getUTCFullYear(); - d.m = week.getUTCMonth(); - d.d = week.getUTCDate() + (d.w + 6) % 7; - } else { - week = localDate(newDate(d.y, 0, 1)), day = week.getDay(); - week = day > 4 || day === 0 ? timeMonday.ceil(week) : timeMonday(week); - week = timeDay.offset(week, (d.V - 1) * 7); - d.y = week.getFullYear(); - d.m = week.getMonth(); - d.d = week.getDate() + (d.w + 6) % 7; - } - } else if ("W" in d || "U" in d) { - if (!("w" in d)) d.w = "u" in d ? d.u % 7 : "W" in d ? 1 : 0; - day = "Z" in d ? utcDate(newDate(d.y, 0, 1)).getUTCDay() : localDate(newDate(d.y, 0, 1)).getDay(); - d.m = 0; - d.d = "W" in d ? (d.w + 6) % 7 + d.W * 7 - (day + 5) % 7 : d.w + d.U * 7 - (day + 6) % 7; - } - - // If a time zone is specified, all fields are interpreted as UTC and then - // offset according to the specified time zone. - if ("Z" in d) { - d.H += d.Z / 100 | 0; - d.M += d.Z % 100; - return utcDate(d); - } - - // Otherwise, all fields are in local time. - return localDate(d); - }; - } - - function parseSpecifier(d, specifier, string, j) { - var i = 0, - n = specifier.length, - m = string.length, - c, - parse; - - while (i < n) { - if (j >= m) return -1; - c = specifier.charCodeAt(i++); - if (c === 37) { - c = specifier.charAt(i++); - parse = parses[c in pads ? specifier.charAt(i++) : c]; - if (!parse || ((j = parse(d, string, j)) < 0)) return -1; - } else if (c != string.charCodeAt(j++)) { - return -1; - } - } - - return j; - } - - function parsePeriod(d, string, i) { - var n = periodRe.exec(string.slice(i)); - return n ? (d.p = periodLookup.get(n[0].toLowerCase()), i + n[0].length) : -1; - } - - function parseShortWeekday(d, string, i) { - var n = shortWeekdayRe.exec(string.slice(i)); - return n ? (d.w = shortWeekdayLookup.get(n[0].toLowerCase()), i + n[0].length) : -1; - } - - function parseWeekday(d, string, i) { - var n = weekdayRe.exec(string.slice(i)); - return n ? (d.w = weekdayLookup.get(n[0].toLowerCase()), i + n[0].length) : -1; - } - - function parseShortMonth(d, string, i) { - var n = shortMonthRe.exec(string.slice(i)); - return n ? (d.m = shortMonthLookup.get(n[0].toLowerCase()), i + n[0].length) : -1; - } - - function parseMonth(d, string, i) { - var n = monthRe.exec(string.slice(i)); - return n ? (d.m = monthLookup.get(n[0].toLowerCase()), i + n[0].length) : -1; - } - - function parseLocaleDateTime(d, string, i) { - return parseSpecifier(d, locale_dateTime, string, i); - } - - function parseLocaleDate(d, string, i) { - return parseSpecifier(d, locale_date, string, i); - } - - function parseLocaleTime(d, string, i) { - return parseSpecifier(d, locale_time, string, i); - } - - function formatShortWeekday(d) { - return locale_shortWeekdays[d.getDay()]; - } - - function formatWeekday(d) { - return locale_weekdays[d.getDay()]; - } - - function formatShortMonth(d) { - return locale_shortMonths[d.getMonth()]; - } - - function formatMonth(d) { - return locale_months[d.getMonth()]; - } - - function formatPeriod(d) { - return locale_periods[+(d.getHours() >= 12)]; - } - - function formatQuarter(d) { - return 1 + ~~(d.getMonth() / 3); - } - - function formatUTCShortWeekday(d) { - return locale_shortWeekdays[d.getUTCDay()]; - } - - function formatUTCWeekday(d) { - return locale_weekdays[d.getUTCDay()]; - } - - function formatUTCShortMonth(d) { - return locale_shortMonths[d.getUTCMonth()]; - } - - function formatUTCMonth(d) { - return locale_months[d.getUTCMonth()]; - } - - function formatUTCPeriod(d) { - return locale_periods[+(d.getUTCHours() >= 12)]; - } - - function formatUTCQuarter(d) { - return 1 + ~~(d.getUTCMonth() / 3); - } - - return { - format: function(specifier) { - var f = newFormat(specifier += "", formats); - f.toString = function() { return specifier; }; - return f; - }, - parse: function(specifier) { - var p = newParse(specifier += "", false); - p.toString = function() { return specifier; }; - return p; - }, - utcFormat: function(specifier) { - var f = newFormat(specifier += "", utcFormats); - f.toString = function() { return specifier; }; - return f; - }, - utcParse: function(specifier) { - var p = newParse(specifier += "", true); - p.toString = function() { return specifier; }; - return p; - } - }; - } - - var pads = {"-": "", "_": " ", "0": "0"}, - numberRe = /^\s*\d+/, // note: ignores next directive - percentRe = /^%/, - requoteRe = /[\\^$*+?|[\]().{}]/g; - - function pad(value, fill, width) { - var sign = value < 0 ? "-" : "", - string = (sign ? -value : value) + "", - length = string.length; - return sign + (length < width ? new Array(width - length + 1).join(fill) + string : string); - } - - function requote(s) { - return s.replace(requoteRe, "\\$&"); - } - - function formatRe(names) { - return new RegExp("^(?:" + names.map(requote).join("|") + ")", "i"); - } - - function formatLookup(names) { - return new Map(names.map((name, i) => [name.toLowerCase(), i])); - } - - function parseWeekdayNumberSunday(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 1)); - return n ? (d.w = +n[0], i + n[0].length) : -1; - } - - function parseWeekdayNumberMonday(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 1)); - return n ? (d.u = +n[0], i + n[0].length) : -1; - } - - function parseWeekNumberSunday(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 2)); - return n ? (d.U = +n[0], i + n[0].length) : -1; - } - - function parseWeekNumberISO(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 2)); - return n ? (d.V = +n[0], i + n[0].length) : -1; - } - - function parseWeekNumberMonday(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 2)); - return n ? (d.W = +n[0], i + n[0].length) : -1; - } - - function parseFullYear(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 4)); - return n ? (d.y = +n[0], i + n[0].length) : -1; - } - - function parseYear(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 2)); - return n ? (d.y = +n[0] + (+n[0] > 68 ? 1900 : 2000), i + n[0].length) : -1; - } - - function parseZone(d, string, i) { - var n = /^(Z)|([+-]\d\d)(?::?(\d\d))?/.exec(string.slice(i, i + 6)); - return n ? (d.Z = n[1] ? 0 : -(n[2] + (n[3] || "00")), i + n[0].length) : -1; - } - - function parseQuarter(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 1)); - return n ? (d.q = n[0] * 3 - 3, i + n[0].length) : -1; - } - - function parseMonthNumber(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 2)); - return n ? (d.m = n[0] - 1, i + n[0].length) : -1; - } - - function parseDayOfMonth(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 2)); - return n ? (d.d = +n[0], i + n[0].length) : -1; - } - - function parseDayOfYear(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 3)); - return n ? (d.m = 0, d.d = +n[0], i + n[0].length) : -1; - } - - function parseHour24(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 2)); - return n ? (d.H = +n[0], i + n[0].length) : -1; - } - - function parseMinutes(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 2)); - return n ? (d.M = +n[0], i + n[0].length) : -1; - } - - function parseSeconds(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 2)); - return n ? (d.S = +n[0], i + n[0].length) : -1; - } - - function parseMilliseconds(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 3)); - return n ? (d.L = +n[0], i + n[0].length) : -1; - } - - function parseMicroseconds(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 6)); - return n ? (d.L = Math.floor(n[0] / 1000), i + n[0].length) : -1; - } - - function parseLiteralPercent(d, string, i) { - var n = percentRe.exec(string.slice(i, i + 1)); - return n ? i + n[0].length : -1; - } - - function parseUnixTimestamp(d, string, i) { - var n = numberRe.exec(string.slice(i)); - return n ? (d.Q = +n[0], i + n[0].length) : -1; - } - - function parseUnixTimestampSeconds(d, string, i) { - var n = numberRe.exec(string.slice(i)); - return n ? (d.s = +n[0], i + n[0].length) : -1; - } - - function formatDayOfMonth(d, p) { - return pad(d.getDate(), p, 2); - } - - function formatHour24(d, p) { - return pad(d.getHours(), p, 2); - } - - function formatHour12(d, p) { - return pad(d.getHours() % 12 || 12, p, 2); - } - - function formatDayOfYear(d, p) { - return pad(1 + timeDay.count(timeYear(d), d), p, 3); - } - - function formatMilliseconds(d, p) { - return pad(d.getMilliseconds(), p, 3); - } - - function formatMicroseconds(d, p) { - return formatMilliseconds(d, p) + "000"; - } - - function formatMonthNumber(d, p) { - return pad(d.getMonth() + 1, p, 2); - } - - function formatMinutes(d, p) { - return pad(d.getMinutes(), p, 2); - } - - function formatSeconds(d, p) { - return pad(d.getSeconds(), p, 2); - } - - function formatWeekdayNumberMonday(d) { - var day = d.getDay(); - return day === 0 ? 7 : day; - } - - function formatWeekNumberSunday(d, p) { - return pad(timeSunday.count(timeYear(d) - 1, d), p, 2); - } - - function dISO(d) { - var day = d.getDay(); - return (day >= 4 || day === 0) ? timeThursday(d) : timeThursday.ceil(d); - } - - function formatWeekNumberISO(d, p) { - d = dISO(d); - return pad(timeThursday.count(timeYear(d), d) + (timeYear(d).getDay() === 4), p, 2); - } - - function formatWeekdayNumberSunday(d) { - return d.getDay(); - } - - function formatWeekNumberMonday(d, p) { - return pad(timeMonday.count(timeYear(d) - 1, d), p, 2); - } - - function formatYear(d, p) { - return pad(d.getFullYear() % 100, p, 2); - } - - function formatYearISO(d, p) { - d = dISO(d); - return pad(d.getFullYear() % 100, p, 2); - } - - function formatFullYear(d, p) { - return pad(d.getFullYear() % 10000, p, 4); - } - - function formatFullYearISO(d, p) { - var day = d.getDay(); - d = (day >= 4 || day === 0) ? timeThursday(d) : timeThursday.ceil(d); - return pad(d.getFullYear() % 10000, p, 4); - } - - function formatZone(d) { - var z = d.getTimezoneOffset(); - return (z > 0 ? "-" : (z *= -1, "+")) - + pad(z / 60 | 0, "0", 2) - + pad(z % 60, "0", 2); - } - - function formatUTCDayOfMonth(d, p) { - return pad(d.getUTCDate(), p, 2); - } - - function formatUTCHour24(d, p) { - return pad(d.getUTCHours(), p, 2); - } - - function formatUTCHour12(d, p) { - return pad(d.getUTCHours() % 12 || 12, p, 2); - } - - function formatUTCDayOfYear(d, p) { - return pad(1 + utcDay.count(utcYear(d), d), p, 3); - } - - function formatUTCMilliseconds(d, p) { - return pad(d.getUTCMilliseconds(), p, 3); - } - - function formatUTCMicroseconds(d, p) { - return formatUTCMilliseconds(d, p) + "000"; - } - - function formatUTCMonthNumber(d, p) { - return pad(d.getUTCMonth() + 1, p, 2); - } - - function formatUTCMinutes(d, p) { - return pad(d.getUTCMinutes(), p, 2); - } - - function formatUTCSeconds(d, p) { - return pad(d.getUTCSeconds(), p, 2); - } - - function formatUTCWeekdayNumberMonday(d) { - var dow = d.getUTCDay(); - return dow === 0 ? 7 : dow; - } - - function formatUTCWeekNumberSunday(d, p) { - return pad(utcSunday.count(utcYear(d) - 1, d), p, 2); - } - - function UTCdISO(d) { - var day = d.getUTCDay(); - return (day >= 4 || day === 0) ? utcThursday(d) : utcThursday.ceil(d); - } - - function formatUTCWeekNumberISO(d, p) { - d = UTCdISO(d); - return pad(utcThursday.count(utcYear(d), d) + (utcYear(d).getUTCDay() === 4), p, 2); - } - - function formatUTCWeekdayNumberSunday(d) { - return d.getUTCDay(); - } - - function formatUTCWeekNumberMonday(d, p) { - return pad(utcMonday.count(utcYear(d) - 1, d), p, 2); - } - - function formatUTCYear(d, p) { - return pad(d.getUTCFullYear() % 100, p, 2); - } - - function formatUTCYearISO(d, p) { - d = UTCdISO(d); - return pad(d.getUTCFullYear() % 100, p, 2); - } - - function formatUTCFullYear(d, p) { - return pad(d.getUTCFullYear() % 10000, p, 4); - } - - function formatUTCFullYearISO(d, p) { - var day = d.getUTCDay(); - d = (day >= 4 || day === 0) ? utcThursday(d) : utcThursday.ceil(d); - return pad(d.getUTCFullYear() % 10000, p, 4); - } - - function formatUTCZone() { - return "+0000"; - } - - function formatLiteralPercent() { - return "%"; - } - - function formatUnixTimestamp(d) { - return +d; - } - - function formatUnixTimestampSeconds(d) { - return Math.floor(+d / 1000); - } - - var locale; - var utcFormat; - var utcParse; - - defaultLocale({ - dateTime: "%x, %X", - date: "%-m/%-d/%Y", - time: "%-I:%M:%S %p", - periods: ["AM", "PM"], - days: ["Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"], - shortDays: ["Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"], - months: ["January", "February", "March", "April", "May", "June", "July", "August", "September", "October", "November", "December"], - shortMonths: ["Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"] - }); - - function defaultLocale(definition) { - locale = formatLocale(definition); - locale.format; - locale.parse; - utcFormat = locale.utcFormat; - utcParse = locale.utcParse; - return locale; - } - - var isoSpecifier = "%Y-%m-%dT%H:%M:%S.%LZ"; - - function formatIsoNative(date) { - return date.toISOString(); - } - - Date.prototype.toISOString - ? formatIsoNative - : utcFormat(isoSpecifier); - - function parseIsoNative(string) { - var date = new Date(string); - return isNaN(date) ? null : date; - } - - +new Date("2000-01-01T00:00:00.000Z") - ? parseIsoNative - : utcParse(isoSpecifier); - function transformer() { var x0 = 0, x1 = 1, @@ -6087,7 +4748,7 @@ var drawChart = (function (exports) { }; }; - const StaticContext = B$2({}); + const StaticContext = G({}); const drawChart = (parentNode, data, width, height) => { const availableSizeProperties = getAvailableSizeOptions(data.options); console.time("layout create"); @@ -6131,7 +4792,7 @@ var drawChart = (function (exports) { console.time("color"); const getModuleColor = createRainbowColor(rawHierarchy); console.timeEnd("color"); - P(o$1(StaticContext.Provider, Object.assign({ value: { + D(o$1(StaticContext.Provider, { value: { data, availableSizeProperties, width, @@ -6141,7 +4802,7 @@ var drawChart = (function (exports) { getModuleColor, rawHierarchy, layout, - } }, { children: o$1(Main, {}) })), parentNode); + }, children: o$1(Main, {}) }), parentNode); }; exports.StaticContext = StaticContext; @@ -6157,7 +4818,7 @@ var drawChart = (function (exports) {