Merge branch 'main' into doc_updates_23

This commit is contained in:
Kent Keirsey 2023-07-06 11:24:42 -04:00 committed by GitHub
commit b250d1ec86
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
227 changed files with 8142 additions and 6571 deletions

2
.gitignore vendored
View File

@ -201,8 +201,6 @@ checkpoints
# If it's a Mac # If it's a Mac
.DS_Store .DS_Store
invokeai/frontend/web/dist/*
# Let the frontend manage its own gitignore # Let the frontend manage its own gitignore
!invokeai/frontend/web/* !invokeai/frontend/web/*

View File

@ -2,17 +2,17 @@
from typing import Literal, Optional, Union from typing import Literal, Optional, Union
from fastapi import Query from fastapi import Query, Body
from fastapi.routing import APIRouter, HTTPException from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
from invokeai.backend import BaseModelType, ModelType from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management import AddModelResult
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)] MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
models_router = APIRouter(prefix="/v1/models", tags=["models"]) models_router = APIRouter(prefix="/v1/models", tags=["models"])
class VaeRepo(BaseModel): class VaeRepo(BaseModel):
repo_id: str = Field(description="The repo ID to use for this VAE") repo_id: str = Field(description="The repo ID to use for this VAE")
path: Optional[str] = Field(description="The path to the VAE") path: Optional[str] = Field(description="The path to the VAE")
@ -51,9 +51,12 @@ class CreateModelResponse(BaseModel):
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info") info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
status: str = Field(description="The status of the API response") status: str = Field(description="The status of the API response")
class ImportModelRequest(BaseModel): class ImportModelResponse(BaseModel):
name: str = Field(description="A model path, repo_id or URL to import") name: str = Field(description="The name of the imported model")
prediction_type: Optional[Literal['epsilon','v_prediction','sample']] = Field(description='Prediction type for SDv2 checkpoint files') # base_model: str = Field(description="The base model")
# model_type: str = Field(description="The model type")
info: AddModelResult = Field(description="The model info")
status: str = Field(description="The status of the API response")
class ConversionRequest(BaseModel): class ConversionRequest(BaseModel):
name: str = Field(description="The name of the new model") name: str = Field(description="The name of the new model")
@ -86,7 +89,6 @@ async def list_models(
models = parse_obj_as(ModelsList, { "models": models_raw }) models = parse_obj_as(ModelsList, { "models": models_raw })
return models return models
@models_router.post( @models_router.post(
"/", "/",
operation_id="update_model", operation_id="update_model",
@ -109,27 +111,38 @@ async def update_model(
return model_response return model_response
@models_router.post( @models_router.post(
"/", "/import",
operation_id="import_model", operation_id="import_model",
responses={200: {"status": "success"}}, responses= {
201: {"description" : "The model imported successfully"},
404: {"description" : "The model could not be found"},
},
status_code=201,
response_model=ImportModelResponse
) )
async def import_model( async def import_model(
model_request: ImportModelRequest name: str = Query(description="A model path, repo_id or URL to import"),
) -> None: prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = Query(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
""" Add Model """ ) -> ImportModelResponse:
items_to_import = set([model_request.name]) """ Add a model using its local path, repo_id, or remote URL """
items_to_import = {name}
prediction_types = { x.value: x for x in SchedulerPredictionType } prediction_types = { x.value: x for x in SchedulerPredictionType }
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import( installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
items_to_import = items_to_import, items_to_import = items_to_import,
prediction_type_helper = lambda x: prediction_types.get(model_request.prediction_type) prediction_type_helper = lambda x: prediction_types.get(prediction_type)
)
if info := installed_models.get(name):
logger.info(f'Successfully imported {name}, got {info}')
return ImportModelResponse(
name = name,
info = info,
status = "success",
) )
if len(installed_models) > 0:
logger.info(f'Successfully imported {model_request.name}')
else: else:
logger.error(f'Model {model_request.name} not imported') logger.error(f'Model {name} not imported')
raise HTTPException(status_code=500, detail=f'Model {model_request.name} not imported') raise HTTPException(status_code=404, detail=f'Model {name} not found')
@models_router.delete( @models_router.delete(
"/{model_name}", "/{model_name}",

View File

@ -4,9 +4,10 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from inspect import signature from inspect import signature
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict, TYPE_CHECKING from typing import (TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args,
get_type_hints)
from pydantic import BaseModel, Field from pydantic import BaseConfig, BaseModel, Field
if TYPE_CHECKING: if TYPE_CHECKING:
from ..services.invocation_services import InvocationServices from ..services.invocation_services import InvocationServices
@ -65,7 +66,12 @@ class BaseInvocation(ABC, BaseModel):
@classmethod @classmethod
def get_invocations_map(cls): def get_invocations_map(cls):
# Get the type strings out of the literals and into a dictionary # Get the type strings out of the literals and into a dictionary
return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseInvocation.get_all_subclasses())) return dict(
map(
lambda t: (get_args(get_type_hints(t)["type"])[0], t),
BaseInvocation.get_all_subclasses(),
)
)
@classmethod @classmethod
def get_output_type(cls): def get_output_type(cls):
@ -76,10 +82,10 @@ class BaseInvocation(ABC, BaseModel):
"""Invoke with provided context and return outputs.""" """Invoke with provided context and return outputs."""
pass pass
#fmt: off # fmt: off
id: str = Field(description="The id of this node. Must be unique among all nodes.") id: str = Field(description="The id of this node. Must be unique among all nodes.")
is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.") is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
#fmt: on # fmt: on
# TODO: figure out a better way to provide these hints # TODO: figure out a better way to provide these hints
@ -97,16 +103,20 @@ class UIConfig(TypedDict, total=False):
"latents", "latents",
"model", "model",
"control", "control",
"image_collection",
"vae_model",
"lora_model",
], ],
] ]
tags: List[str] tags: List[str]
title: str title: str
class CustomisedSchemaExtra(TypedDict): class CustomisedSchemaExtra(TypedDict):
ui: UIConfig ui: UIConfig
class InvocationConfig(BaseModel.Config): class InvocationConfig(BaseConfig):
"""Customizes pydantic's BaseModel.Config class for use by Invocations. """Customizes pydantic's BaseModel.Config class for use by Invocations.
Provide `schema_extra` a `ui` dict to add hints for generated UIs. Provide `schema_extra` a `ui` dict to add hints for generated UIs.

View File

@ -4,13 +4,16 @@ from typing import Literal
import numpy as np import numpy as np
from pydantic import Field, validator from pydantic import Field, validator
from invokeai.app.models.image import ImageField
from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.misc import SEED_MAX, get_random_seed
from .baseinvocation import ( from .baseinvocation import (
BaseInvocation, BaseInvocation,
InvocationConfig,
InvocationContext, InvocationContext,
BaseInvocationOutput, BaseInvocationOutput,
UIConfig,
) )
@ -22,6 +25,7 @@ class IntCollectionOutput(BaseInvocationOutput):
# Outputs # Outputs
collection: list[int] = Field(default=[], description="The int collection") collection: list[int] = Field(default=[], description="The int collection")
class FloatCollectionOutput(BaseInvocationOutput): class FloatCollectionOutput(BaseInvocationOutput):
"""A collection of floats""" """A collection of floats"""
@ -31,6 +35,18 @@ class FloatCollectionOutput(BaseInvocationOutput):
collection: list[float] = Field(default=[], description="The float collection") collection: list[float] = Field(default=[], description="The float collection")
class ImageCollectionOutput(BaseInvocationOutput):
"""A collection of images"""
type: Literal["image_collection"] = "image_collection"
# Outputs
collection: list[ImageField] = Field(default=[], description="The output images")
class Config:
schema_extra = {"required": ["type", "collection"]}
class RangeInvocation(BaseInvocation): class RangeInvocation(BaseInvocation):
"""Creates a range of numbers from start to stop with step""" """Creates a range of numbers from start to stop with step"""
@ -92,3 +108,27 @@ class RandomRangeInvocation(BaseInvocation):
return IntCollectionOutput( return IntCollectionOutput(
collection=list(rng.integers(low=self.low, high=self.high, size=self.size)) collection=list(rng.integers(low=self.low, high=self.high, size=self.size))
) )
class ImageCollectionInvocation(BaseInvocation):
"""Load a collection of images and provide it as output."""
# fmt: off
type: Literal["image_collection"] = "image_collection"
# Inputs
images: list[ImageField] = Field(
default=[], description="The image collection to load"
)
# fmt: on
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
return ImageCollectionOutput(collection=self.images)
class Config(InvocationConfig):
schema_extra = {
"ui": {
"type_hints": {
"images": "image_collection",
}
},
}

View File

@ -1,27 +1,28 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel, Field
from contextlib import ExitStack
import re import re
from contextlib import ExitStack
from typing import List, Literal, Optional, Union
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig import torch
from .model import ClipField from compel import Compel
from compel.prompt_parser import (Blend, Conjunction,
CrossAttentionControlSubstitute,
FlattenedPrompt, Fragment)
from pydantic import BaseModel, Field
from ...backend.util.devices import torch_dtype from ...backend.model_management.models import ModelNotFoundException
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.model_management import BaseModelType, ModelType, SubModelType from ...backend.model_management import BaseModelType, ModelType, SubModelType
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from compel import Compel from ...backend.util.devices import torch_dtype
from compel.prompt_parser import ( from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
Blend, InvocationConfig, InvocationContext)
CrossAttentionControlSubstitute, from .model import ClipField
FlattenedPrompt,
Fragment, Conjunction,
)
class ConditioningField(BaseModel): class ConditioningField(BaseModel):
conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data") conditioning_name: Optional[str] = Field(
default=None, description="The name of conditioning data")
class Config: class Config:
schema_extra = {"required": ["conditioning_name"]} schema_extra = {"required": ["conditioning_name"]}
@ -56,18 +57,24 @@ class CompelInvocation(BaseInvocation):
}, },
} }
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput: def invoke(self, context: InvocationContext) -> CompelOutput:
tokenizer_info = context.services.model_manager.get_model( tokenizer_info = context.services.model_manager.get_model(
**self.clip.tokenizer.dict(), **self.clip.tokenizer.dict(),
) )
text_encoder_info = context.services.model_manager.get_model( text_encoder_info = context.services.model_manager.get_model(
**self.clip.text_encoder.dict(), **self.clip.text_encoder.dict(),
) )
with tokenizer_info as orig_tokenizer,\
text_encoder_info as text_encoder:
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] def _lora_loader():
for lora in self.clip.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
del lora_info
return
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = [] ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
@ -80,14 +87,15 @@ class CompelInvocation(BaseInvocation):
model_type=ModelType.TextualInversion, model_type=ModelType.TextualInversion,
).context.model ).context.model
) )
except Exception: except ModelNotFoundException:
#print(e) # print(e)
#import traceback #import traceback
#print(traceback.format_exc()) #print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found") print(f"Warn: trigger: \"{trigger}\" not found")
with ModelPatcher.apply_lora_text_encoder(text_encoder, loras),\ with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
ModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager): ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
text_encoder_info as text_encoder:
compel = Compel( compel = Compel(
tokenizer=tokenizer, tokenizer=tokenizer,
@ -103,15 +111,17 @@ class CompelInvocation(BaseInvocation):
if context.services.configuration.log_tokenization: if context.services.configuration.log_tokenization:
log_tokenization_for_prompt_object(prompt, tokenizer) log_tokenization_for_prompt_object(prompt, tokenizer)
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt) c, options = compel.build_conditioning_tensor_for_prompt_object(
prompt)
# TODO: long prompt support # TODO: long prompt support
#if not self.truncate_long_prompts: # if not self.truncate_long_prompts:
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc]) # [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo( ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction), tokens_count_including_eos_bos=get_max_token_count(
cross_attention_control_args=options.get("cross_attention_control", None), tokenizer, conjunction),
) cross_attention_control_args=options.get(
"cross_attention_control", None),)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
@ -126,8 +136,8 @@ class CompelInvocation(BaseInvocation):
def get_max_token_count( def get_max_token_count(
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction],
) -> int: truncate_if_too_long=False) -> int:
if type(prompt) is Blend: if type(prompt) is Blend:
blend: Blend = prompt blend: Blend = prompt
return max( return max(
@ -146,13 +156,13 @@ def get_max_token_count(
) )
else: else:
return len( return len(
get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long) get_tokens_for_prompt_object(
) tokenizer, prompt, truncate_if_too_long))
def get_tokens_for_prompt_object( def get_tokens_for_prompt_object(
tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
) -> [str]: ) -> List[str]:
if type(parsed_prompt) is Blend: if type(parsed_prompt) is Blend:
raise ValueError( raise ValueError(
"Blend is not supported here - you need to get tokens for each of its .children" "Blend is not supported here - you need to get tokens for each of its .children"
@ -181,7 +191,7 @@ def log_tokenization_for_conjunction(
): ):
display_label_prefix = display_label_prefix or "" display_label_prefix = display_label_prefix or ""
for i, p in enumerate(c.prompts): for i, p in enumerate(c.prompts):
if len(c.prompts)>1: if len(c.prompts) > 1:
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})" this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
else: else:
this_display_label_prefix = display_label_prefix this_display_label_prefix = display_label_prefix
@ -236,7 +246,8 @@ def log_tokenization_for_prompt_object(
) )
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False): def log_tokenization_for_text(
text, tokenizer, display_label=None, truncate_if_too_long=False):
"""shows how the prompt is tokenized """shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word, # usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' ' # but for readability it has been replaced with ' '

View File

@ -4,18 +4,17 @@ from contextlib import ExitStack
from typing import List, Literal, Optional, Union from typing import List, Literal, Optional, Union
import einops import einops
from pydantic import BaseModel, Field, validator
import torch import torch
from diffusers import ControlNetModel, DPMSolverMultistepScheduler from diffusers import ControlNetModel, DPMSolverMultistepScheduler
from diffusers.image_processor import VaeImageProcessor from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import BaseModel, Field, validator
from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ...backend.image_util.seamless import configure_model_padding from ...backend.image_util.seamless import configure_model_padding
from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import ( from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline, ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline,
@ -24,7 +23,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
PostprocessingSettings PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import torch_dtype from ...backend.util.devices import torch_dtype
from ...backend.model_management.lora import ModelPatcher from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (BaseInvocation, BaseInvocationOutput, from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext) InvocationConfig, InvocationContext)
from .compel import ConditioningField from .compel import ConditioningField
@ -32,14 +31,17 @@ from .controlnet_image_processors import ControlField
from .image import ImageOutput from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField from .model import ModelInfo, UNetField, VaeField
class LatentsField(BaseModel): class LatentsField(BaseModel):
"""A latents field used for passing latents between invocations""" """A latents field used for passing latents between invocations"""
latents_name: Optional[str] = Field(default=None, description="The name of the latents") latents_name: Optional[str] = Field(
default=None, description="The name of the latents")
class Config: class Config:
schema_extra = {"required": ["latents_name"]} schema_extra = {"required": ["latents_name"]}
class LatentsOutput(BaseInvocationOutput): class LatentsOutput(BaseInvocationOutput):
"""Base class for invocations that output latents""" """Base class for invocations that output latents"""
#fmt: off #fmt: off
@ -70,14 +72,17 @@ def get_scheduler(
scheduler_info: ModelInfo, scheduler_info: ModelInfo,
scheduler_name: str, scheduler_name: str,
) -> Scheduler: ) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim']) scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict()) scheduler_name, SCHEDULER_MAP['ddim'])
orig_scheduler_info = context.services.model_manager.get_model(
**scheduler_info.dict())
with orig_scheduler_info as orig_scheduler: with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config scheduler_config = orig_scheduler.config
if "_backup" in scheduler_config: if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"] scheduler_config = scheduler_config["_backup"]
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config} scheduler_config = {**scheduler_config, **
scheduler_extra_config, "_backup": scheduler_config}
scheduler = scheduler_class.from_config(scheduler_config) scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py # hack copied over from generate.py
@ -134,8 +139,8 @@ class TextToLatentsInvocation(BaseInvocation):
# TODO: pass this an emitter method or something? or a session for dispatching? # TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress( def dispatch_progress(
self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState self, context: InvocationContext, source_node_id: str,
) -> None: intermediate_state: PipelineIntermediateState) -> None:
stable_diffusion_step_callback( stable_diffusion_step_callback(
context=context, context=context,
intermediate_state=intermediate_state, intermediate_state=intermediate_state,
@ -143,9 +148,12 @@ class TextToLatentsInvocation(BaseInvocation):
source_node_id=source_node_id, source_node_id=source_node_id,
) )
def get_conditioning_data(self, context: InvocationContext, scheduler) -> ConditioningData: def get_conditioning_data(
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name) self, context: InvocationContext, scheduler) -> ConditioningData:
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name) c, extra_conditioning_info = context.services.latents.get(
self.positive_conditioning.conditioning_name)
uc, _ = context.services.latents.get(
self.negative_conditioning.conditioning_name)
conditioning_data = ConditioningData( conditioning_data = ConditioningData(
unconditioned_embeddings=uc, unconditioned_embeddings=uc,
@ -153,10 +161,10 @@ class TextToLatentsInvocation(BaseInvocation):
guidance_scale=self.cfg_scale, guidance_scale=self.cfg_scale,
extra=extra_conditioning_info, extra=extra_conditioning_info,
postprocessing_settings=PostprocessingSettings( postprocessing_settings=PostprocessingSettings(
threshold=0.0,#threshold, threshold=0.0, # threshold,
warmup=0.2,#warmup, warmup=0.2, # warmup,
h_symmetry_time_pct=None,#h_symmetry_time_pct, h_symmetry_time_pct=None, # h_symmetry_time_pct,
v_symmetry_time_pct=None#v_symmetry_time_pct, v_symmetry_time_pct=None # v_symmetry_time_pct,
), ),
) )
@ -164,20 +172,21 @@ class TextToLatentsInvocation(BaseInvocation):
scheduler, scheduler,
# for ddim scheduler # for ddim scheduler
eta=0.0, #ddim_eta eta=0.0, # ddim_eta
# for ancestral and sde schedulers # for ancestral and sde schedulers
generator=torch.Generator(device=uc.device).manual_seed(0), generator=torch.Generator(device=uc.device).manual_seed(0),
) )
return conditioning_data return conditioning_data
def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline: def create_pipeline(
self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
# TODO: # TODO:
#configure_model_padding( # configure_model_padding(
# unet, # unet,
# self.seamless, # self.seamless,
# self.seamless_axes, # self.seamless_axes,
#) # )
class FakeVae: class FakeVae:
class FakeVaeConfig: class FakeVaeConfig:
@ -202,7 +211,8 @@ class TextToLatentsInvocation(BaseInvocation):
def prep_control_data( def prep_control_data(
self, self,
context: InvocationContext, context: InvocationContext,
model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device # really only need model for dtype and device
model: StableDiffusionGeneratorPipeline,
control_input: List[ControlField], control_input: List[ControlField],
latents_shape: List[int], latents_shape: List[int],
do_classifier_free_guidance: bool = True, do_classifier_free_guidance: bool = True,
@ -238,15 +248,17 @@ class TextToLatentsInvocation(BaseInvocation):
print("Using HF model subfolders") print("Using HF model subfolders")
print(" control_name: ", control_name) print(" control_name: ", control_name)
print(" control_subfolder: ", control_subfolder) print(" control_subfolder: ", control_subfolder)
control_model = ControlNetModel.from_pretrained(control_name, control_model = ControlNetModel.from_pretrained(
subfolder=control_subfolder, control_name, subfolder=control_subfolder,
torch_dtype=model.unet.dtype).to(model.device) torch_dtype=model.unet.dtype).to(
model.device)
else: else:
control_model = ControlNetModel.from_pretrained(control_info.control_model, control_model = ControlNetModel.from_pretrained(
torch_dtype=model.unet.dtype).to(model.device) control_info.control_model, torch_dtype=model.unet.dtype).to(model.device)
control_models.append(control_model) control_models.append(control_model)
control_image_field = control_info.image control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_name) input_image = context.services.images.get_pil_image(
control_image_field.image_name)
# self.image.image_type, self.image.image_name # self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes # FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt? # and add in batch_size, num_images_per_prompt?
@ -263,29 +275,40 @@ class TextToLatentsInvocation(BaseInvocation):
dtype=control_model.dtype, dtype=control_model.dtype,
control_mode=control_info.control_mode, control_mode=control_info.control_mode,
) )
control_item = ControlNetData(model=control_model, control_item = ControlNetData(
image_tensor=control_image, model=control_model, image_tensor=control_image,
weight=control_info.control_weight, weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent, begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent, end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode, control_mode=control_info.control_mode,)
)
control_data.append(control_item) control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData] # MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data return control_data
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name) noise = context.services.latents.get(self.noise.latents_name)
# Get the source node id (we are invoking the prepared node) # Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id] source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState): def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state) self.dispatch_progress(context, source_node_id, state)
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict()) def _lora_loader():
with unet_info as unet: for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict())
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet:
scheduler = get_scheduler( scheduler = get_scheduler(
context=context, context=context,
@ -296,8 +319,6 @@ class TextToLatentsInvocation(BaseInvocation):
pipeline = self.create_pipeline(unet, scheduler) pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler) conditioning_data = self.get_conditioning_data(context, scheduler)
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
control_data = self.prep_control_data( control_data = self.prep_control_data(
model=pipeline, context=context, control_input=self.control, model=pipeline, context=context, control_input=self.control,
latents_shape=noise.shape, latents_shape=noise.shape,
@ -305,7 +326,6 @@ class TextToLatentsInvocation(BaseInvocation):
do_classifier_free_guidance=True, do_classifier_free_guidance=True,
) )
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
# TODO: Verify the noise is the right size # TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)), latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
@ -323,14 +343,18 @@ class TextToLatentsInvocation(BaseInvocation):
context.services.latents.save(name, result_latents) context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents) return build_latents_output(latents_name=name, latents=result_latents)
class LatentsToLatentsInvocation(TextToLatentsInvocation): class LatentsToLatentsInvocation(TextToLatentsInvocation):
"""Generates latents using latents as base image.""" """Generates latents using latents as base image."""
type: Literal["l2l"] = "l2l" type: Literal["l2l"] = "l2l"
# Inputs # Inputs
latents: Optional[LatentsField] = Field(description="The latents to use as a base image") latents: Optional[LatentsField] = Field(
strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use") description="The latents to use as a base image")
strength: float = Field(
default=0.7, ge=0, le=1,
description="The strength of the latents to use")
# Schema customisation # Schema customisation
class Config(InvocationConfig): class Config(InvocationConfig):
@ -345,22 +369,31 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
}, },
} }
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name) noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name) latent = context.services.latents.get(self.latents.latents_name)
# Get the source node id (we are invoking the prepared node) # Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id] source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState): def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state) self.dispatch_progress(context, source_node_id, state)
unet_info = context.services.model_manager.get_model( def _lora_loader():
**self.unet.unet.dict(), for lora in self.unet.loras:
) lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
del lora_info
return
with unet_info as unet: unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict())
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet:
scheduler = get_scheduler( scheduler = get_scheduler(
context=context, context=context,
@ -380,8 +413,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
# TODO: Verify the noise is the right size # TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like( initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
latent, device=unet.device, dtype=latent.dtype latent, device=unet.device, dtype=latent.dtype)
)
timesteps, _ = pipeline.get_img2img_timesteps( timesteps, _ = pipeline.get_img2img_timesteps(
self.steps, self.steps,
@ -389,9 +421,6 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
device=unet.device, device=unet.device,
) )
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=initial_latents, latents=initial_latents,
timesteps=timesteps, timesteps=timesteps,
@ -417,9 +446,12 @@ class LatentsToImageInvocation(BaseInvocation):
type: Literal["l2i"] = "l2i" type: Literal["l2i"] = "l2i"
# Inputs # Inputs
latents: Optional[LatentsField] = Field(description="The latents to generate an image from") latents: Optional[LatentsField] = Field(
description="The latents to generate an image from")
vae: VaeField = Field(default=None, description="Vae submodel") vae: VaeField = Field(default=None, description="Vae submodel")
tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)") tiled: bool = Field(
default=False,
description="Decode latents by overlaping tiles(less memory consumption)")
# Schema customisation # Schema customisation
class Config(InvocationConfig): class Config(InvocationConfig):
@ -473,9 +505,9 @@ class LatentsToImageInvocation(BaseInvocation):
height=image_dto.height, height=image_dto.height,
) )
LATENTS_INTERPOLATION_MODE = Literal[
"nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact" LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear",
] "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
class ResizeLatentsInvocation(BaseInvocation): class ResizeLatentsInvocation(BaseInvocation):
@ -484,21 +516,25 @@ class ResizeLatentsInvocation(BaseInvocation):
type: Literal["lresize"] = "lresize" type: Literal["lresize"] = "lresize"
# Inputs # Inputs
latents: Optional[LatentsField] = Field(description="The latents to resize") latents: Optional[LatentsField] = Field(
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)") description="The latents to resize")
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)") width: int = Field(
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode") ge=64, multiple_of=8, description="The width to resize to (px)")
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") height: int = Field(
ge=64, multiple_of=8, description="The height to resize to (px)")
mode: LATENTS_INTERPOLATION_MODE = Field(
default="bilinear", description="The interpolation mode")
antialias: bool = Field(
default=False,
description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name) latents = context.services.latents.get(self.latents.latents_name)
resized_latents = torch.nn.functional.interpolate( resized_latents = torch.nn.functional.interpolate(
latents, latents, size=(self.height // 8, self.width // 8),
size=(self.height // 8, self.width // 8), mode=self.mode, antialias=self.antialias
mode=self.mode, if self.mode in ["bilinear", "bicubic"] else False,)
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -515,21 +551,24 @@ class ScaleLatentsInvocation(BaseInvocation):
type: Literal["lscale"] = "lscale" type: Literal["lscale"] = "lscale"
# Inputs # Inputs
latents: Optional[LatentsField] = Field(description="The latents to scale") latents: Optional[LatentsField] = Field(
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents") description="The latents to scale")
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode") scale_factor: float = Field(
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") gt=0, description="The factor by which to scale the latents")
mode: LATENTS_INTERPOLATION_MODE = Field(
default="bilinear", description="The interpolation mode")
antialias: bool = Field(
default=False,
description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name) latents = context.services.latents.get(self.latents.latents_name)
# resizing # resizing
resized_latents = torch.nn.functional.interpolate( resized_latents = torch.nn.functional.interpolate(
latents, latents, scale_factor=self.scale_factor, mode=self.mode,
scale_factor=self.scale_factor, antialias=self.antialias
mode=self.mode, if self.mode in ["bilinear", "bicubic"] else False,)
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -548,7 +587,9 @@ class ImageToLatentsInvocation(BaseInvocation):
# Inputs # Inputs
image: Union[ImageField, None] = Field(description="The image to encode") image: Union[ImageField, None] = Field(description="The image to encode")
vae: VaeField = Field(default=None, description="Vae submodel") vae: VaeField = Field(default=None, description="Vae submodel")
tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)") tiled: bool = Field(
default=False,
description="Encode latents by overlaping tiles(less memory consumption)")
# Schema customisation # Schema customisation
class Config(InvocationConfig): class Config(InvocationConfig):

View File

@ -1,31 +1,38 @@
from typing import Literal, Optional, Union, List
from pydantic import BaseModel, Field
import copy import copy
from typing import List, Literal, Optional, Union
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig from pydantic import BaseModel, Field
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.model_management import BaseModelType, ModelType, SubModelType from ...backend.model_management import BaseModelType, ModelType, SubModelType
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
class ModelInfo(BaseModel): class ModelInfo(BaseModel):
model_name: str = Field(description="Info to load submodel") model_name: str = Field(description="Info to load submodel")
base_model: BaseModelType = Field(description="Base model") base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Info to load submodel") model_type: ModelType = Field(description="Info to load submodel")
submodel: Optional[SubModelType] = Field(description="Info to load submodel") submodel: Optional[SubModelType] = Field(
default=None, description="Info to load submodel"
)
class LoraInfo(ModelInfo): class LoraInfo(ModelInfo):
weight: float = Field(description="Lora's weight which to use when apply to model") weight: float = Field(description="Lora's weight which to use when apply to model")
class UNetField(BaseModel): class UNetField(BaseModel):
unet: ModelInfo = Field(description="Info to load unet submodel") unet: ModelInfo = Field(description="Info to load unet submodel")
scheduler: ModelInfo = Field(description="Info to load scheduler submodel") scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading") loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
class ClipField(BaseModel): class ClipField(BaseModel):
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel") tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel") text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading") loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
class VaeField(BaseModel): class VaeField(BaseModel):
# TODO: better naming? # TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel") vae: ModelInfo = Field(description="Info to load vae submodel")
@ -34,43 +41,48 @@ class VaeField(BaseModel):
class ModelLoaderOutput(BaseInvocationOutput): class ModelLoaderOutput(BaseInvocationOutput):
"""Model loader output""" """Model loader output"""
#fmt: off # fmt: off
type: Literal["model_loader_output"] = "model_loader_output" type: Literal["model_loader_output"] = "model_loader_output"
unet: UNetField = Field(default=None, description="UNet submodel") unet: UNetField = Field(default=None, description="UNet submodel")
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
vae: VaeField = Field(default=None, description="Vae submodel") vae: VaeField = Field(default=None, description="Vae submodel")
#fmt: on # fmt: on
class PipelineModelField(BaseModel): class MainModelField(BaseModel):
"""Pipeline model field""" """Main model field"""
model_name: str = Field(description="Name of the model") model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model") base_model: BaseModelType = Field(description="Base model")
class PipelineModelLoaderInvocation(BaseInvocation): class LoRAModelField(BaseModel):
"""Loads a pipeline model, outputting its submodels.""" """LoRA model field"""
type: Literal["pipeline_model_loader"] = "pipeline_model_loader" model_name: str = Field(description="Name of the LoRA model")
base_model: BaseModelType = Field(description="Base model")
model: PipelineModelField = Field(description="The model to load")
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
type: Literal["main_model_loader"] = "main_model_loader"
model: MainModelField = Field(description="The model to load")
# TODO: precision? # TODO: precision?
# Schema customisation # Schema customisation
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {
"title": "Model Loader",
"tags": ["model", "loader"], "tags": ["model", "loader"],
"type_hints": { "type_hints": {"model": "model"},
"model": "model"
}
}, },
} }
def invoke(self, context: InvocationContext) -> ModelLoaderOutput: def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = self.model.base_model base_model = self.model.base_model
model_name = self.model.model_name model_name = self.model.model_name
model_type = ModelType.Main model_type = ModelType.Main
@ -112,7 +124,6 @@ class PipelineModelLoaderInvocation(BaseInvocation):
) )
""" """
return ModelLoaderOutput( return ModelLoaderOutput(
unet=UNetField( unet=UNetField(
unet=ModelInfo( unet=ModelInfo(
@ -151,47 +162,66 @@ class PipelineModelLoaderInvocation(BaseInvocation):
model_type=model_type, model_type=model_type,
submodel=SubModelType.Vae, submodel=SubModelType.Vae,
), ),
),
) )
)
class LoraLoaderOutput(BaseInvocationOutput): class LoraLoaderOutput(BaseInvocationOutput):
"""Model loader output""" """Model loader output"""
#fmt: off # fmt: off
type: Literal["lora_loader_output"] = "lora_loader_output" type: Literal["lora_loader_output"] = "lora_loader_output"
unet: Optional[UNetField] = Field(default=None, description="UNet submodel") unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels") clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
#fmt: on # fmt: on
class LoraLoaderInvocation(BaseInvocation): class LoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder.""" """Apply selected lora to unet and text_encoder."""
type: Literal["lora_loader"] = "lora_loader" type: Literal["lora_loader"] = "lora_loader"
lora_name: str = Field(description="Lora model name") lora: Union[LoRAModelField, None] = Field(
default=None, description="Lora model name"
)
weight: float = Field(default=0.75, description="With what weight to apply lora") weight: float = Field(default=0.75, description="With what weight to apply lora")
unet: Optional[UNetField] = Field(description="UNet model for applying lora") unet: Optional[UNetField] = Field(description="UNet model for applying lora")
clip: Optional[ClipField] = Field(description="Clip model for applying lora") clip: Optional[ClipField] = Field(description="Clip model for applying lora")
def invoke(self, context: InvocationContext) -> LoraLoaderOutput: class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Lora Loader",
"tags": ["lora", "loader"],
"type_hints": {"lora": "lora_model"},
},
}
# TODO: ui rewrite def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
base_model = BaseModelType.StableDiffusion1 if self.lora is None:
raise Exception("No LoRA provided")
base_model = self.lora.base_model
lora_name = self.lora.model_name
if not context.services.model_manager.model_exists( if not context.services.model_manager.model_exists(
base_model=base_model, base_model=base_model,
model_name=self.lora_name, model_name=lora_name,
model_type=ModelType.Lora, model_type=ModelType.Lora,
): ):
raise Exception(f"Unkown lora name: {self.lora_name}!") raise Exception(f"Unkown lora name: {lora_name}!")
if self.unet is not None and any(lora.model_name == self.lora_name for lora in self.unet.loras): if self.unet is not None and any(
raise Exception(f"Lora \"{self.lora_name}\" already applied to unet") lora.model_name == lora_name for lora in self.unet.loras
):
raise Exception(f'Lora "{lora_name}" already applied to unet')
if self.clip is not None and any(lora.model_name == self.lora_name for lora in self.clip.loras): if self.clip is not None and any(
raise Exception(f"Lora \"{self.lora_name}\" already applied to clip") lora.model_name == lora_name for lora in self.clip.loras
):
raise Exception(f'Lora "{lora_name}" already applied to clip')
output = LoraLoaderOutput() output = LoraLoaderOutput()
@ -200,7 +230,7 @@ class LoraLoaderInvocation(BaseInvocation):
output.unet.loras.append( output.unet.loras.append(
LoraInfo( LoraInfo(
base_model=base_model, base_model=base_model,
model_name=self.lora_name, model_name=lora_name,
model_type=ModelType.Lora, model_type=ModelType.Lora,
submodel=None, submodel=None,
weight=self.weight, weight=self.weight,
@ -212,7 +242,7 @@ class LoraLoaderInvocation(BaseInvocation):
output.clip.loras.append( output.clip.loras.append(
LoraInfo( LoraInfo(
base_model=base_model, base_model=base_model,
model_name=self.lora_name, model_name=lora_name,
model_type=ModelType.Lora, model_type=ModelType.Lora,
submodel=None, submodel=None,
weight=self.weight, weight=self.weight,
@ -221,3 +251,58 @@ class LoraLoaderInvocation(BaseInvocation):
return output return output
class VAEModelField(BaseModel):
"""Vae model field"""
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
class VaeLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
# fmt: off
type: Literal["vae_loader_output"] = "vae_loader_output"
vae: VaeField = Field(default=None, description="Vae model")
# fmt: on
class VaeLoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""
type: Literal["vae_loader"] = "vae_loader"
vae_model: VAEModelField = Field(description="The VAE to load")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "VAE Loader",
"tags": ["vae", "loader"],
"type_hints": {"vae_model": "vae_model"},
},
}
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
base_model = self.vae_model.base_model
model_name = self.vae_model.model_name
model_type = ModelType.Vae
if not context.services.model_manager.model_exists(
base_model=base_model,
model_name=model_name,
model_type=model_type,
):
raise Exception(f"Unkown vae name: {model_name}!")
return VaeLoaderOutput(
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
)
)

View File

@ -367,7 +367,8 @@ setting environment variables INVOKEAI_<setting>.
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance') always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance') free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
max_loaded_models : int = Field(default=3, gt=0, description="Maximum number of models to keep in memory for rapid switching", category='Memory/Performance') max_loaded_models : int = Field(default=3, gt=0, description="(DEPRECATED: use max_cache_size) Maximum number of models to keep in memory for rapid switching", category='Memory/Performance')
max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance') precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance')
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance') sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance') xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')

View File

@ -7,7 +7,7 @@ if TYPE_CHECKING:
from invokeai.app.services.board_images import BoardImagesServiceABC from invokeai.app.services.board_images import BoardImagesServiceABC
from invokeai.app.services.boards import BoardServiceABC from invokeai.app.services.boards import BoardServiceABC
from invokeai.app.services.images import ImageServiceABC from invokeai.app.services.images import ImageServiceABC
from invokeai.backend import ModelManager from invokeai.app.services.model_manager_service import ModelManagerServiceBase
from invokeai.app.services.events import EventServiceBase from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.latent_storage import LatentsStorageBase from invokeai.app.services.latent_storage import LatentsStorageBase
from invokeai.app.services.restoration_services import RestorationServices from invokeai.app.services.restoration_services import RestorationServices
@ -22,46 +22,47 @@ class InvocationServices:
"""Services that can be used by invocations""" """Services that can be used by invocations"""
# TODO: Just forward-declared everything due to circular dependencies. Fix structure. # TODO: Just forward-declared everything due to circular dependencies. Fix structure.
events: "EventServiceBase"
latents: "LatentsStorageBase"
queue: "InvocationQueueABC"
model_manager: "ModelManager"
restoration: "RestorationServices"
configuration: "InvokeAISettings"
images: "ImageServiceABC"
boards: "BoardServiceABC"
board_images: "BoardImagesServiceABC" board_images: "BoardImagesServiceABC"
graph_library: "ItemStorageABC"["LibraryGraph"] boards: "BoardServiceABC"
configuration: "InvokeAISettings"
events: "EventServiceBase"
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"] graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
graph_library: "ItemStorageABC"["LibraryGraph"]
images: "ImageServiceABC"
latents: "LatentsStorageBase"
logger: "Logger"
model_manager: "ModelManagerServiceBase"
processor: "InvocationProcessorABC" processor: "InvocationProcessorABC"
queue: "InvocationQueueABC"
restoration: "RestorationServices"
def __init__( def __init__(
self, self,
model_manager: "ModelManager",
events: "EventServiceBase",
logger: "Logger",
latents: "LatentsStorageBase",
images: "ImageServiceABC",
boards: "BoardServiceABC",
board_images: "BoardImagesServiceABC", board_images: "BoardImagesServiceABC",
queue: "InvocationQueueABC", boards: "BoardServiceABC",
graph_library: "ItemStorageABC"["LibraryGraph"],
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
processor: "InvocationProcessorABC",
restoration: "RestorationServices",
configuration: "InvokeAISettings", configuration: "InvokeAISettings",
events: "EventServiceBase",
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
graph_library: "ItemStorageABC"["LibraryGraph"],
images: "ImageServiceABC",
latents: "LatentsStorageBase",
logger: "Logger",
model_manager: "ModelManagerServiceBase",
processor: "InvocationProcessorABC",
queue: "InvocationQueueABC",
restoration: "RestorationServices",
): ):
self.model_manager = model_manager
self.events = events
self.logger = logger
self.latents = latents
self.images = images
self.boards = boards
self.board_images = board_images self.board_images = board_images
self.queue = queue
self.graph_library = graph_library
self.graph_execution_manager = graph_execution_manager
self.processor = processor
self.restoration = restoration
self.configuration = configuration
self.boards = boards self.boards = boards
self.boards = boards
self.configuration = configuration
self.events = events
self.graph_execution_manager = graph_execution_manager
self.graph_library = graph_library
self.images = images
self.latents = latents
self.logger = logger
self.model_manager = model_manager
self.processor = processor
self.queue = queue
self.restoration = restoration

View File

@ -135,6 +135,29 @@ class ModelManagerServiceBase(ABC):
""" """
pass pass
@abstractmethod
def heuristic_import(self,
items_to_import: Set[str],
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
)->Dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
The prediction type helper is necessary to distinguish between
models based on Stable Diffusion 2 Base (requiring
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
(requiring SchedulerPredictionType.VPrediction). It is
generally impossible to do this programmatically, so the
prediction_type_helper usually asks the user to choose.
The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.
'''
pass
@abstractmethod @abstractmethod
def commit(self, conf_file: Path = None) -> None: def commit(self, conf_file: Path = None) -> None:
""" """
@ -183,6 +206,8 @@ class ModelManagerService(ModelManagerServiceBase):
if hasattr(config,'max_cache_size') \ if hasattr(config,'max_cache_size') \
else config.max_loaded_models * 2.5 else config.max_loaded_models * 2.5
logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")
sequential_offload = config.sequential_guidance sequential_offload = config.sequential_guidance
self.mgr = ModelManager( self.mgr = ModelManager(
@ -361,3 +386,24 @@ class ModelManagerService(ModelManagerServiceBase):
def logger(self): def logger(self):
return self.mgr.logger return self.mgr.logger
def heuristic_import(self,
items_to_import: Set[str],
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
)->Dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
The prediction type helper is necessary to distinguish between
models based on Stable Diffusion 2 Base (requiring
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
(requiring SchedulerPredictionType.VPrediction). It is
generally impossible to do this programmatically, so the
prediction_type_helper usually asks the user to choose.
The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.
'''
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)

View File

@ -430,13 +430,13 @@ to allow InvokeAI to download restricted styles & subjects from the "Concept Lib
max_height=len(PRECISION_CHOICES) + 1, max_height=len(PRECISION_CHOICES) + 1,
scroll_exit=True, scroll_exit=True,
) )
self.max_loaded_models = self.add_widget_intelligent( self.max_cache_size = self.add_widget_intelligent(
IntTitleSlider, IntTitleSlider,
name="Number of models to cache in CPU memory (each will use 2-4 GB!)", name="Size of the RAM cache used for fast model switching (GB)",
value=old_opts.max_loaded_models, value=old_opts.max_cache_size,
out_of=10, out_of=20,
lowest=1, lowest=3,
begin_entry_at=4, begin_entry_at=6,
scroll_exit=True, scroll_exit=True,
) )
self.nextrely += 1 self.nextrely += 1
@ -539,7 +539,7 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
"outdir", "outdir",
"nsfw_checker", "nsfw_checker",
"free_gpu_mem", "free_gpu_mem",
"max_loaded_models", "max_cache_size",
"xformers_enabled", "xformers_enabled",
"always_use_cpu", "always_use_cpu",
]: ]:
@ -555,9 +555,6 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
new_opts.license_acceptance = self.license_acceptance.value new_opts.license_acceptance = self.license_acceptance.value
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]] new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
# widget library workaround to make max_loaded_models an int rather than a float
new_opts.max_loaded_models = int(new_opts.max_loaded_models)
return new_opts return new_opts

View File

@ -4,6 +4,8 @@ import argparse
import shlex import shlex
from argparse import ArgumentParser from argparse import ArgumentParser
# note that this includes both old sampler names and new scheduler names
# in order to be able to parse both 2.0 and 3.0-pre-nodes versions of invokeai.init
SAMPLER_CHOICES = [ SAMPLER_CHOICES = [
"ddim", "ddim",
"ddpm", "ddpm",
@ -27,6 +29,15 @@ SAMPLER_CHOICES = [
"dpmpp_sde", "dpmpp_sde",
"dpmpp_sde_k", "dpmpp_sde_k",
"unipc", "unipc",
"k_dpm_2_a",
"k_dpm_2",
"k_dpmpp_2_a",
"k_dpmpp_2",
"k_euler_a",
"k_euler",
"k_heun",
"k_lms",
"plms",
] ]
PRECISION_CHOICES = [ PRECISION_CHOICES = [

View File

@ -76,6 +76,10 @@ class MigrateTo3(object):
Create a unique name for a model for use within models.yaml. Create a unique name for a model for use within models.yaml.
''' '''
done = False done = False
# some model names have slashes in them, which really screws things up
name = name.replace('/','_')
key = ModelManager.create_key(name,info.base_type,info.model_type) key = ModelManager.create_key(name,info.base_type,info.model_type)
unique_name = key unique_name = key
counter = 1 counter = 1
@ -219,11 +223,12 @@ class MigrateTo3(object):
repo_id = 'openai/clip-vit-large-patch14' repo_id = 'openai/clip-vit-large-patch14'
self._migrate_pretrained(CLIPTokenizer, self._migrate_pretrained(CLIPTokenizer,
repo_id= repo_id, repo_id= repo_id,
dest= target_dir / 'clip-vit-large-patch14' / 'tokenizer', dest= target_dir / 'clip-vit-large-patch14',
**kwargs) **kwargs)
self._migrate_pretrained(CLIPTextModel, self._migrate_pretrained(CLIPTextModel,
repo_id = repo_id, repo_id = repo_id,
dest = target_dir / 'clip-vit-large-patch14' / 'text_encoder', dest = target_dir / 'clip-vit-large-patch14',
force = True,
**kwargs) **kwargs)
# sd-2 # sd-2
@ -287,18 +292,18 @@ class MigrateTo3(object):
def _model_probe_to_path(self, info: ModelProbeInfo)->Path: def _model_probe_to_path(self, info: ModelProbeInfo)->Path:
return Path(self.dest_models, info.base_type.value, info.model_type.value) return Path(self.dest_models, info.base_type.value, info.model_type.value)
def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, **kwargs): def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force:bool=False, **kwargs):
if dest.exists(): if dest.exists() and not force:
logger.info(f'Skipping existing {dest}') logger.info(f'Skipping existing {dest}')
return return
model = model_class.from_pretrained(repo_id, **kwargs) model = model_class.from_pretrained(repo_id, **kwargs)
self._save_pretrained(model, dest) self._save_pretrained(model, dest, overwrite=force)
def _save_pretrained(self, model, dest: Path): def _save_pretrained(self, model, dest: Path, overwrite: bool=False):
if dest.exists():
logger.info(f'Skipping existing {dest}')
return
model_name = dest.name model_name = dest.name
if overwrite:
model.save_pretrained(dest, safe_serialization=True)
else:
download_path = dest.with_name(f'{model_name}.downloading') download_path = dest.with_name(f'{model_name}.downloading')
model.save_pretrained(download_path, safe_serialization=True) model.save_pretrained(download_path, safe_serialization=True)
download_path.replace(dest) download_path.replace(dest)
@ -569,8 +574,10 @@ script, which will perform a full upgrade in place."""
dest_directory = args.dest_directory dest_directory = args.dest_directory
assert dest_directory.is_dir(), f"{dest_directory} is not a valid directory" assert dest_directory.is_dir(), f"{dest_directory} is not a valid directory"
assert (dest_directory / 'models').is_dir(), f"{dest_directory} does not contain a 'models' subdirectory"
assert (dest_directory / 'invokeai.yaml').exists(), f"{dest_directory} does not contain an InvokeAI init file." # TODO: revisit
# assert (dest_directory / 'models').is_dir(), f"{dest_directory} does not contain a 'models' subdirectory"
# assert (dest_directory / 'invokeai.yaml').exists(), f"{dest_directory} does not contain an InvokeAI init file."
do_migrate(root_directory,dest_directory) do_migrate(root_directory,dest_directory)

View File

@ -18,7 +18,7 @@ from tqdm import tqdm
import invokeai.configs as configs import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
from invokeai.backend.util import download_with_resume from invokeai.backend.util import download_with_resume
from ..util.logging import InvokeAILogger from ..util.logging import InvokeAILogger
@ -166,17 +166,22 @@ class ModelInstall(object):
# add requested models # add requested models
for path in selections.install_models: for path in selections.install_models:
logger.info(f'Installing {path} [{job}/{jobs}]') logger.info(f'Installing {path} [{job}/{jobs}]')
self.heuristic_install(path) self.heuristic_import(path)
job += 1 job += 1
self.mgr.commit() self.mgr.commit()
def heuristic_install(self, def heuristic_import(self,
model_path_id_or_url: Union[str,Path], model_path_id_or_url: Union[str,Path],
models_installed: Set[Path]=None)->Set[Path]: models_installed: Set[Path]=None)->Dict[str, AddModelResult]:
'''
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
:param models_installed: Set of installed models, used for recursive invocation
Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
'''
if not models_installed: if not models_installed:
models_installed = set() models_installed = dict()
# A little hack to allow nested routines to retrieve info on the requested ID # A little hack to allow nested routines to retrieve info on the requested ID
self.current_id = model_path_id_or_url self.current_id = model_path_id_or_url
@ -185,24 +190,27 @@ class ModelInstall(object):
try: try:
# checkpoint file, or similar # checkpoint file, or similar
if path.is_file(): if path.is_file():
models_installed.add(self._install_path(path)) models_installed.update(self._install_path(path))
# folders style or similar # folders style or similar
elif path.is_dir() and any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]): elif path.is_dir() and any([(path/x).exists() for x in \
models_installed.add(self._install_path(path)) {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
]
):
models_installed.update(self._install_path(path))
# recursive scan # recursive scan
elif path.is_dir(): elif path.is_dir():
for child in path.iterdir(): for child in path.iterdir():
self.heuristic_install(child, models_installed=models_installed) self.heuristic_import(child, models_installed=models_installed)
# huggingface repo # huggingface repo
elif len(str(path).split('/')) == 2: elif len(str(path).split('/')) == 2:
models_installed.add(self._install_repo(str(path))) models_installed.update(self._install_repo(str(path)))
# a URL # a URL
elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")): elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")):
models_installed.add(self._install_url(model_path_id_or_url)) models_installed.update(self._install_url(model_path_id_or_url))
else: else:
logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping') logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
@ -214,24 +222,25 @@ class ModelInstall(object):
# install a model from a local path. The optional info parameter is there to prevent # install a model from a local path. The optional info parameter is there to prevent
# the model from being probed twice in the event that it has already been probed. # the model from being probed twice in the event that it has already been probed.
def _install_path(self, path: Path, info: ModelProbeInfo=None)->Path: def _install_path(self, path: Path, info: ModelProbeInfo=None)->Dict[str, AddModelResult]:
try: try:
# logger.debug(f'Probing {path}') model_result = None
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper) info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
model_name = path.stem if info.format=='checkpoint' else path.name model_name = path.stem if path.is_file() else path.name
if self.mgr.model_exists(model_name, info.base_type, info.model_type): if self.mgr.model_exists(model_name, info.base_type, info.model_type):
raise ValueError(f'A model named "{model_name}" is already installed.') raise ValueError(f'A model named "{model_name}" is already installed.')
attributes = self._make_attributes(path,info) attributes = self._make_attributes(path,info)
self.mgr.add_model(model_name = model_name, model_result = self.mgr.add_model(model_name = model_name,
base_model = info.base_type, base_model = info.base_type,
model_type = info.model_type, model_type = info.model_type,
model_attributes = attributes, model_attributes = attributes,
) )
except Exception as e: except Exception as e:
logger.warning(f'{str(e)} Skipping registration.') logger.warning(f'{str(e)} Skipping registration.')
return path return {}
return {str(path): model_result}
def _install_url(self, url: str)->Path: def _install_url(self, url: str)->dict:
# copy to a staging area, probe, import and delete # copy to a staging area, probe, import and delete
with TemporaryDirectory(dir=self.config.models_path) as staging: with TemporaryDirectory(dir=self.config.models_path) as staging:
location = download_with_resume(url,Path(staging)) location = download_with_resume(url,Path(staging))
@ -244,7 +253,7 @@ class ModelInstall(object):
# staged version will be garbage-collected at this time # staged version will be garbage-collected at this time
return self._install_path(Path(models_path), info) return self._install_path(Path(models_path), info)
def _install_repo(self, repo_id: str)->Path: def _install_repo(self, repo_id: str)->dict:
hinfo = HfApi().model_info(repo_id) hinfo = HfApi().model_info(repo_id)
# we try to figure out how to download this most economically # we try to figure out how to download this most economically

View File

@ -1,7 +1,7 @@
""" """
Initialization file for invokeai.backend.model_management Initialization file for invokeai.backend.model_management
""" """
from .model_manager import ModelManager, ModelInfo from .model_manager import ModelManager, ModelInfo, AddModelResult
from .model_cache import ModelCache from .model_cache import ModelCache
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType from .models import BaseModelType, ModelType, SubModelType, ModelVariantType

View File

@ -29,7 +29,7 @@ import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from .model_manager import ModelManager from .model_manager import ModelManager
from .model_cache import ModelCache from picklescan.scanner import scan_file_path
from .models import BaseModelType, ModelVariantType from .models import BaseModelType, ModelVariantType
try: try:
@ -1014,7 +1014,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint = load_file(checkpoint_path) checkpoint = load_file(checkpoint_path)
else: else:
if scan_needed: if scan_needed:
ModelCache.scan_model(checkpoint_path, checkpoint_path) # scan model
scan_result = scan_file_path(checkpoint_path)
if scan_result.infected_files != 0:
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
checkpoint = torch.load(checkpoint_path) checkpoint = torch.load(checkpoint_path)
# sometimes there is a state_dict key and sometimes not # sometimes there is a state_dict key and sometimes not

View File

@ -1,18 +1,15 @@
from __future__ import annotations from __future__ import annotations
import copy import copy
from pathlib import Path
from contextlib import contextmanager from contextlib import contextmanager
from typing import Optional, Dict, Tuple, Any from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union, List
import torch import torch
from safetensors.torch import load_file
from torch.utils.hooks import RemovableHandle
from diffusers.models import UNet2DConditionModel
from transformers import CLIPTextModel
from compel.embeddings_provider import BaseTextualInversionManager from compel.embeddings_provider import BaseTextualInversionManager
from diffusers.models import UNet2DConditionModel
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer
class LoRALayerBase: class LoRALayerBase:
#rank: Optional[int] #rank: Optional[int]
@ -124,8 +121,8 @@ class LoRALayer(LoRALayerBase):
def get_weight(self): def get_weight(self):
if self.mid is not None: if self.mid is not None:
up = self.up.reshape(up.shape[0], up.shape[1]) up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(up.shape[0], up.shape[1]) down = self.down.reshape(self.down.shape[0], self.down.shape[1])
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down) weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
else: else:
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1) weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
@ -411,7 +408,7 @@ class LoRAModel: #(torch.nn.Module):
else: else:
# TODO: diff/ia3/... format # TODO: diff/ia3/... format
print( print(
f">> Encountered unknown lora layer module in {self.name}: {layer_key}" f">> Encountered unknown lora layer module in {model.name}: {layer_key}"
) )
return return
@ -539,7 +536,8 @@ class ModelPatcher:
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True) original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
# enable autocast to calc fp16 loras on cpu # enable autocast to calc fp16 loras on cpu
with torch.autocast(device_type="cpu"): #with torch.autocast(device_type="cpu"):
layer.to(dtype=torch.float32)
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
layer_weight = layer.get_weight() * lora_weight * layer_scale layer_weight = layer.get_weight() * lora_weight * layer_scale
@ -655,6 +653,9 @@ class TextualInversionModel:
else: else:
result.embedding = next(iter(state_dict.values())) result.embedding = next(iter(state_dict.values()))
if len(result.embedding.shape) == 1:
result.embedding = result.embedding.unsqueeze(0)
if not isinstance(result.embedding, torch.Tensor): if not isinstance(result.embedding, torch.Tensor):
raise ValueError(f"Invalid embeddings file: {file_path.name}") raise ValueError(f"Invalid embeddings file: {file_path.name}")

View File

@ -8,7 +8,7 @@ The cache returns context manager generators designed to load the
model into the GPU within the context, and unload outside the model into the GPU within the context, and unload outside the
context. Use like this: context. Use like this:
cache = ModelCache(max_models_cached=6) cache = ModelCache(max_cache_size=7.5)
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1, with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
cache.get_model('stabilityai/stable-diffusion-2') as SD2: cache.get_model('stabilityai/stable-diffusion-2') as SD2:
do_something_in_GPU(SD1,SD2) do_something_in_GPU(SD1,SD2)
@ -91,7 +91,7 @@ class ModelCache(object):
logger: types.ModuleType = logger logger: types.ModuleType = logger
): ):
''' '''
:param max_models: Maximum number of models to cache in CPU RAM [4] :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
:param execution_device: Torch device to load active model into [torch.device('cuda')] :param execution_device: Torch device to load active model into [torch.device('cuda')]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')] :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param precision: Precision for loaded models [torch.float16] :param precision: Precision for loaded models [torch.float16]
@ -100,8 +100,6 @@ class ModelCache(object):
:param sha_chunksize: Chunksize to use when calculating sha256 model hash :param sha_chunksize: Chunksize to use when calculating sha256 model hash
''' '''
#max_cache_size = 9999 #max_cache_size = 9999
execution_device = torch.device('cuda')
self.model_infos: Dict[str, ModelBase] = dict() self.model_infos: Dict[str, ModelBase] = dict()
self.lazy_offloading = lazy_offloading self.lazy_offloading = lazy_offloading
#self.sequential_offload: bool=sequential_offload #self.sequential_offload: bool=sequential_offload
@ -128,16 +126,6 @@ class ModelCache(object):
key += f":{submodel_type}" key += f":{submodel_type}"
return key return key
#def get_model(
# self,
# repo_id_or_path: Union[str, Path],
# model_type: ModelType = ModelType.Diffusers,
# subfolder: Path = None,
# submodel: ModelType = None,
# revision: str = None,
# attach_model_part: Tuple[ModelType, str] = (None, None),
# gpu_load: bool = True,
#) -> ModelLocker: # ?? what does it return
def _get_model_info( def _get_model_info(
self, self,
model_path: str, model_path: str,

View File

@ -233,14 +233,14 @@ import hashlib
import textwrap import textwrap
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Optional, List, Tuple, Union, Set, Callable, types from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types
from shutil import rmtree from shutil import rmtree
import torch import torch
from omegaconf import OmegaConf from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig from omegaconf.dictconfig import DictConfig
from pydantic import BaseModel from pydantic import BaseModel, Field
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
@ -249,7 +249,7 @@ from .model_cache import ModelCache, ModelLocker
from .models import ( from .models import (
BaseModelType, ModelType, SubModelType, BaseModelType, ModelType, SubModelType,
ModelError, SchedulerPredictionType, MODEL_CLASSES, ModelError, SchedulerPredictionType, MODEL_CLASSES,
ModelConfigBase, ModelConfigBase, ModelNotFoundException,
) )
# We are only starting to number the config file with release 3. # We are only starting to number the config file with release 3.
@ -278,8 +278,13 @@ class InvalidModelError(Exception):
"Raised when an invalid model is requested" "Raised when an invalid model is requested"
pass pass
MAX_CACHE_SIZE = 6.0 # GB class AddModelResult(BaseModel):
name: str = Field(description="The name of the model after import")
model_type: ModelType = Field(description="The type of model")
base_model: BaseModelType = Field(description="The base model")
config: ModelConfigBase = Field(description="The configuration of the model")
MAX_CACHE_SIZE = 6.0 # GB
class ConfigMeta(BaseModel): class ConfigMeta(BaseModel):
version: str version: str
@ -404,7 +409,7 @@ class ModelManager(object):
if model_key not in self.models: if model_key not in self.models:
self.scan_models_directory(base_model=base_model, model_type=model_type) self.scan_models_directory(base_model=base_model, model_type=model_type)
if model_key not in self.models: if model_key not in self.models:
raise Exception(f"Model not found - {model_key}") raise ModelNotFoundException(f"Model not found - {model_key}")
model_config = self.models[model_key] model_config = self.models[model_key]
model_path = self.app_config.root_path / model_config.path model_path = self.app_config.root_path / model_config.path
@ -416,14 +421,14 @@ class ModelManager(object):
else: else:
self.models.pop(model_key, None) self.models.pop(model_key, None)
raise Exception(f"Model not found - {model_key}") raise ModelNotFoundException(f"Model not found - {model_key}")
# vae/movq override # vae/movq override
# TODO: # TODO:
if submodel_type is not None and hasattr(model_config, submodel_type): if submodel_type is not None and hasattr(model_config, submodel_type):
override_path = getattr(model_config, submodel_type) override_path = getattr(model_config, submodel_type)
if override_path: if override_path:
model_path = override_path model_path = self.app_config.root_path / override_path
model_type = submodel_type model_type = submodel_type
submodel_type = None submodel_type = None
model_class = MODEL_CLASSES[base_model][model_type] model_class = MODEL_CLASSES[base_model][model_type]
@ -431,6 +436,7 @@ class ModelManager(object):
# TODO: path # TODO: path
# TODO: is it accurate to use path as id # TODO: is it accurate to use path as id
dst_convert_path = self._get_model_cache_path(model_path) dst_convert_path = self._get_model_cache_path(model_path)
model_path = model_class.convert_if_required( model_path = model_class.convert_if_required(
base_model=base_model, base_model=base_model,
model_path=str(model_path), # TODO: refactor str/Path types logic model_path=str(model_path), # TODO: refactor str/Path types logic
@ -570,13 +576,16 @@ class ModelManager(object):
model_type: ModelType, model_type: ModelType,
model_attributes: dict, model_attributes: dict,
clobber: bool = False, clobber: bool = False,
) -> None: ) -> AddModelResult:
""" """
Update the named model with a dictionary of attributes. Will fail with an Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite. assertion error if the name already exists. Pass clobber=True to overwrite.
On a successful update, the config will be changed in memory and the On a successful update, the config will be changed in memory and the
method will return True. Will fail with an assertion error if provided method will return True. Will fail with an assertion error if provided
attributes are incorrect or the model name is missing. attributes are incorrect or the model name is missing.
The returned dict has the same format as the dict returned by
model_info().
""" """
model_class = MODEL_CLASSES[base_model][model_type] model_class = MODEL_CLASSES[base_model][model_type]
@ -600,12 +609,18 @@ class ModelManager(object):
old_model_cache.unlink() old_model_cache.unlink()
# remove in-memory cache # remove in-memory cache
# note: it not garantie to release memory(model can has other references) # note: it not guaranteed to release memory(model can has other references)
cache_ids = self.cache_keys.pop(model_key, []) cache_ids = self.cache_keys.pop(model_key, [])
for cache_id in cache_ids: for cache_id in cache_ids:
self.cache.uncache_model(cache_id) self.cache.uncache_model(cache_id)
self.models[model_key] = model_config self.models[model_key] = model_config
return AddModelResult(
name = model_name,
model_type = model_type,
base_model = base_model,
config = model_config,
)
def search_models(self, search_folder): def search_models(self, search_folder):
self.logger.info(f"Finding Models In: {search_folder}") self.logger.info(f"Finding Models In: {search_folder}")
@ -728,7 +743,7 @@ class ModelManager(object):
if (new_models_found or imported_models) and self.config_path: if (new_models_found or imported_models) and self.config_path:
self.commit() self.commit()
def autoimport(self)->set[Path]: def autoimport(self)->Dict[str, AddModelResult]:
''' '''
Scan the autoimport directory (if defined) and import new models, delete defunct models. Scan the autoimport directory (if defined) and import new models, delete defunct models.
''' '''
@ -741,7 +756,6 @@ class ModelManager(object):
prediction_type_helper = ask_user_for_prediction_type, prediction_type_helper = ask_user_for_prediction_type,
) )
installed = set()
scanned_dirs = set() scanned_dirs = set()
config = self.app_config config = self.app_config
@ -755,13 +769,14 @@ class ModelManager(object):
continue continue
self.logger.info(f'Scanning {autodir} for models to import') self.logger.info(f'Scanning {autodir} for models to import')
installed = dict()
autodir = self.app_config.root_path / autodir autodir = self.app_config.root_path / autodir
if not autodir.exists(): if not autodir.exists():
continue continue
items_scanned = 0 items_scanned = 0
new_models_found = set() new_models_found = dict()
for root, dirs, files in os.walk(autodir): for root, dirs, files in os.walk(autodir):
items_scanned += len(dirs) + len(files) items_scanned += len(dirs) + len(files)
@ -770,8 +785,8 @@ class ModelManager(object):
if path in known_paths or path.parent in scanned_dirs: if path in known_paths or path.parent in scanned_dirs:
scanned_dirs.add(path) scanned_dirs.add(path)
continue continue
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]): if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
new_models_found.update(installer.heuristic_install(path)) new_models_found.update(installer.heuristic_import(path))
scanned_dirs.add(path) scanned_dirs.add(path)
for f in files: for f in files:
@ -779,7 +794,8 @@ class ModelManager(object):
if path in known_paths or path.parent in scanned_dirs: if path in known_paths or path.parent in scanned_dirs:
continue continue
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}: if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
new_models_found.update(installer.heuristic_install(path)) import_result = installer.heuristic_import(path)
new_models_found.update(import_result)
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models') self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
installed.update(new_models_found) installed.update(new_models_found)
@ -789,7 +805,7 @@ class ModelManager(object):
def heuristic_import(self, def heuristic_import(self,
items_to_import: Set[str], items_to_import: Set[str],
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
)->Set[str]: )->Dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of '''Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items. successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported. :param items_to_import: Set of strings corresponding to models to be imported.
@ -802,17 +818,20 @@ class ModelManager(object):
generally impossible to do this programmatically, so the generally impossible to do this programmatically, so the
prediction_type_helper usually asks the user to choose. prediction_type_helper usually asks the user to choose.
The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.
''' '''
# avoid circular import here # avoid circular import here
from invokeai.backend.install.model_install_backend import ModelInstall from invokeai.backend.install.model_install_backend import ModelInstall
successfully_installed = set() successfully_installed = dict()
installer = ModelInstall(config = self.app_config, installer = ModelInstall(config = self.app_config,
prediction_type_helper = prediction_type_helper, prediction_type_helper = prediction_type_helper,
model_manager = self) model_manager = self)
for thing in items_to_import: for thing in items_to_import:
try: try:
installed = installer.heuristic_install(thing) installed = installer.heuristic_import(thing)
successfully_installed.update(installed) successfully_installed.update(installed)
except Exception as e: except Exception as e:
self.logger.warning(f'{thing} could not be imported: {str(e)}') self.logger.warning(f'{thing} could not be imported: {str(e)}')

View File

@ -78,7 +78,6 @@ class ModelProbe(object):
format_type = 'diffusers' if model_path.is_dir() else 'checkpoint' format_type = 'diffusers' if model_path.is_dir() else 'checkpoint'
else: else:
format_type = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint' format_type = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
model_info = None model_info = None
try: try:
model_type = cls.get_model_type_from_folder(model_path, model) \ model_type = cls.get_model_type_from_folder(model_path, model) \
@ -105,7 +104,7 @@ class ModelProbe(object):
) else 512, ) else 512,
) )
except Exception: except Exception:
return None raise
return model_info return model_info
@ -127,6 +126,8 @@ class ModelProbe(object):
return ModelType.Vae return ModelType.Vae
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}): elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
return ModelType.Lora return ModelType.Lora
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
return ModelType.Lora
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}): elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
return ModelType.ControlNet return ModelType.ControlNet
elif key in {"emb_params", "string_to_param"}: elif key in {"emb_params", "string_to_param"}:
@ -137,7 +138,7 @@ class ModelProbe(object):
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()): if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
return ModelType.TextualInversion return ModelType.TextualInversion
raise ValueError("Unable to determine model type") raise ValueError(f"Unable to determine model type for {model_path}")
@classmethod @classmethod
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType: def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
@ -167,7 +168,7 @@ class ModelProbe(object):
return type return type
# give up # give up
raise ValueError("Unable to determine model type") raise ValueError("Unable to determine model type for {folder_path}")
@classmethod @classmethod
def _scan_and_load_checkpoint(cls,model_path: Path)->dict: def _scan_and_load_checkpoint(cls,model_path: Path)->dict:

View File

@ -2,7 +2,7 @@ import inspect
from enum import Enum from enum import Enum
from pydantic import BaseModel from pydantic import BaseModel
from typing import Literal, get_origin from typing import Literal, get_origin
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .vae import VaeModel from .vae import VaeModel
from .lora import LoRAModel from .lora import LoRAModel

View File

@ -15,6 +15,9 @@ from contextlib import suppress
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
class ModelNotFoundException(Exception):
pass
class BaseModelType(str, Enum): class BaseModelType(str, Enum):
StableDiffusion1 = "sd-1" StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2" StableDiffusion2 = "sd-2"

View File

@ -8,6 +8,7 @@ from .base import (
ModelType, ModelType,
SubModelType, SubModelType,
classproperty, classproperty,
ModelNotFoundException,
) )
# TODO: naming # TODO: naming
from ..lora import TextualInversionModel as TextualInversionModelRaw from ..lora import TextualInversionModel as TextualInversionModelRaw
@ -37,8 +38,15 @@ class TextualInversionModel(ModelBase):
if child_type is not None: if child_type is not None:
raise Exception("There is no child models in textual inversion") raise Exception("There is no child models in textual inversion")
checkpoint_path = self.model_path
if os.path.isdir(checkpoint_path):
checkpoint_path = os.path.join(checkpoint_path, "learned_embeds.bin")
if not os.path.exists(checkpoint_path):
raise ModelNotFoundException()
model = TextualInversionModelRaw.from_checkpoint( model = TextualInversionModelRaw.from_checkpoint(
file_path=self.model_path, file_path=checkpoint_path,
dtype=torch_dtype, dtype=torch_dtype,
) )

View File

@ -678,9 +678,8 @@ def select_and_download_models(opt: Namespace):
# this is where the TUI is called # this is where the TUI is called
else: else:
# needed because the torch library is loaded, even though we don't use it # needed to support the probe() method running under a subprocess
# currently commented out because it has started generating errors (?) torch.multiprocessing.set_start_method("spawn")
# torch.multiprocessing.set_start_method("spawn")
# the third argument is needed in the Windows 11 environment in # the third argument is needed in the Windows 11 environment in
# order to launch and resize a console window running this program # order to launch and resize a console window running this program

View File

@ -36,6 +36,12 @@ module.exports = {
], ],
'prettier/prettier': ['error', { endOfLine: 'auto' }], 'prettier/prettier': ['error', { endOfLine: 'auto' }],
'@typescript-eslint/ban-ts-comment': 'warn', '@typescript-eslint/ban-ts-comment': 'warn',
'@typescript-eslint/no-empty-interface': [
'error',
{
allowSingleExtends: true,
},
],
}, },
settings: { settings: {
react: { react: {

View File

@ -12,7 +12,7 @@
margin: 0; margin: 0;
} }
</style> </style>
<script type="module" crossorigin src="./assets/index-8a3e9251.js"></script> <script type="module" crossorigin src="./assets/index-c0367e37.js"></script>
</head> </head>
<body dir="ltr"> <body dir="ltr">

View File

@ -24,16 +24,13 @@
}, },
"common": { "common": {
"hotkeysLabel": "Hotkeys", "hotkeysLabel": "Hotkeys",
"themeLabel": "Theme", "darkMode": "Dark Mode",
"lightMode": "Light Mode",
"languagePickerLabel": "Language", "languagePickerLabel": "Language",
"reportBugLabel": "Report Bug", "reportBugLabel": "Report Bug",
"githubLabel": "Github", "githubLabel": "Github",
"discordLabel": "Discord", "discordLabel": "Discord",
"settingsLabel": "Settings", "settingsLabel": "Settings",
"darkTheme": "Dark",
"lightTheme": "Light",
"greenTheme": "Green",
"oceanTheme": "Ocean",
"langArabic": "العربية", "langArabic": "العربية",
"langEnglish": "English", "langEnglish": "English",
"langDutch": "Nederlands", "langDutch": "Nederlands",
@ -55,6 +52,7 @@
"unifiedCanvas": "Unified Canvas", "unifiedCanvas": "Unified Canvas",
"linear": "Linear", "linear": "Linear",
"nodes": "Node Editor", "nodes": "Node Editor",
"modelmanager": "Model Manager",
"postprocessing": "Post Processing", "postprocessing": "Post Processing",
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.", "nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
"postProcessing": "Post Processing", "postProcessing": "Post Processing",
@ -336,6 +334,7 @@
"modelManager": { "modelManager": {
"modelManager": "Model Manager", "modelManager": "Model Manager",
"model": "Model", "model": "Model",
"vae": "VAE",
"allModels": "All Models", "allModels": "All Models",
"checkpointModels": "Checkpoints", "checkpointModels": "Checkpoints",
"diffusersModels": "Diffusers", "diffusersModels": "Diffusers",
@ -351,6 +350,7 @@
"scanForModels": "Scan For Models", "scanForModels": "Scan For Models",
"addManually": "Add Manually", "addManually": "Add Manually",
"manual": "Manual", "manual": "Manual",
"baseModel": "Base Model",
"name": "Name", "name": "Name",
"nameValidationMsg": "Enter a name for your model", "nameValidationMsg": "Enter a name for your model",
"description": "Description", "description": "Description",
@ -363,6 +363,7 @@
"repoIDValidationMsg": "Online repository of your model", "repoIDValidationMsg": "Online repository of your model",
"vaeLocation": "VAE Location", "vaeLocation": "VAE Location",
"vaeLocationValidationMsg": "Path to where your VAE is located.", "vaeLocationValidationMsg": "Path to where your VAE is located.",
"variant": "Variant",
"vaeRepoID": "VAE Repo ID", "vaeRepoID": "VAE Repo ID",
"vaeRepoIDValidationMsg": "Online repository of your VAE", "vaeRepoIDValidationMsg": "Online repository of your VAE",
"width": "Width", "width": "Width",
@ -524,7 +525,8 @@
"initialImage": "Initial Image", "initialImage": "Initial Image",
"showOptionsPanel": "Show Options Panel", "showOptionsPanel": "Show Options Panel",
"hidePreview": "Hide Preview", "hidePreview": "Hide Preview",
"showPreview": "Show Preview" "showPreview": "Show Preview",
"controlNetControlMode": "Control Mode"
}, },
"settings": { "settings": {
"models": "Models", "models": "Models",
@ -547,7 +549,8 @@
"general": "General", "general": "General",
"generation": "Generation", "generation": "Generation",
"ui": "User Interface", "ui": "User Interface",
"availableSchedulers": "Available Schedulers" "favoriteSchedulers": "Favorite Schedulers",
"favoriteSchedulersPlaceholder": "No schedulers favorited"
}, },
"toast": { "toast": {
"serverError": "Server Error", "serverError": "Server Error",

View File

@ -67,6 +67,7 @@
"@fontsource-variable/inter": "^5.0.3", "@fontsource-variable/inter": "^5.0.3",
"@fontsource/inter": "^5.0.3", "@fontsource/inter": "^5.0.3",
"@mantine/core": "^6.0.14", "@mantine/core": "^6.0.14",
"@mantine/form": "^6.0.15",
"@mantine/hooks": "^6.0.14", "@mantine/hooks": "^6.0.14",
"@reduxjs/toolkit": "^1.9.5", "@reduxjs/toolkit": "^1.9.5",
"@roarr/browser-log-writer": "^1.1.5", "@roarr/browser-log-writer": "^1.1.5",
@ -82,7 +83,7 @@
"konva": "^9.2.0", "konva": "^9.2.0",
"lodash-es": "^4.17.21", "lodash-es": "^4.17.21",
"nanostores": "^0.9.2", "nanostores": "^0.9.2",
"openapi-fetch": "^0.4.0", "openapi-fetch": "0.4.0",
"overlayscrollbars": "^2.2.0", "overlayscrollbars": "^2.2.0",
"overlayscrollbars-react": "^0.5.0", "overlayscrollbars-react": "^0.5.0",
"patch-package": "^7.0.0", "patch-package": "^7.0.0",

View File

@ -52,6 +52,8 @@
"unifiedCanvas": "Unified Canvas", "unifiedCanvas": "Unified Canvas",
"linear": "Linear", "linear": "Linear",
"nodes": "Node Editor", "nodes": "Node Editor",
"batch": "Batch Manager",
"modelmanager": "Model Manager",
"postprocessing": "Post Processing", "postprocessing": "Post Processing",
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.", "nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
"postProcessing": "Post Processing", "postProcessing": "Post Processing",
@ -333,6 +335,7 @@
"modelManager": { "modelManager": {
"modelManager": "Model Manager", "modelManager": "Model Manager",
"model": "Model", "model": "Model",
"vae": "VAE",
"allModels": "All Models", "allModels": "All Models",
"checkpointModels": "Checkpoints", "checkpointModels": "Checkpoints",
"diffusersModels": "Diffusers", "diffusersModels": "Diffusers",
@ -348,6 +351,7 @@
"scanForModels": "Scan For Models", "scanForModels": "Scan For Models",
"addManually": "Add Manually", "addManually": "Add Manually",
"manual": "Manual", "manual": "Manual",
"baseModel": "Base Model",
"name": "Name", "name": "Name",
"nameValidationMsg": "Enter a name for your model", "nameValidationMsg": "Enter a name for your model",
"description": "Description", "description": "Description",
@ -360,6 +364,7 @@
"repoIDValidationMsg": "Online repository of your model", "repoIDValidationMsg": "Online repository of your model",
"vaeLocation": "VAE Location", "vaeLocation": "VAE Location",
"vaeLocationValidationMsg": "Path to where your VAE is located.", "vaeLocationValidationMsg": "Path to where your VAE is located.",
"variant": "Variant",
"vaeRepoID": "VAE Repo ID", "vaeRepoID": "VAE Repo ID",
"vaeRepoIDValidationMsg": "Online repository of your VAE", "vaeRepoIDValidationMsg": "Online repository of your VAE",
"width": "Width", "width": "Width",

View File

@ -1,67 +1,40 @@
import { Box, Flex, Grid, Portal } from '@chakra-ui/react'; import { Flex, Grid, Portal } from '@chakra-ui/react';
import { useLogger } from 'app/logging/useLogger'; import { useLogger } from 'app/logging/useLogger';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { PartialAppConfig } from 'app/types/invokeai'; import { PartialAppConfig } from 'app/types/invokeai';
import ImageUploader from 'common/components/ImageUploader'; import ImageUploader from 'common/components/ImageUploader';
import Loading from 'common/components/Loading/Loading';
import GalleryDrawer from 'features/gallery/components/GalleryPanel'; import GalleryDrawer from 'features/gallery/components/GalleryPanel';
import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal';
import Lightbox from 'features/lightbox/components/Lightbox'; import Lightbox from 'features/lightbox/components/Lightbox';
import SiteHeader from 'features/system/components/SiteHeader'; import SiteHeader from 'features/system/components/SiteHeader';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useIsApplicationReady } from 'features/system/hooks/useIsApplicationReady';
import { configChanged } from 'features/system/store/configSlice'; import { configChanged } from 'features/system/store/configSlice';
import { languageSelector } from 'features/system/store/systemSelectors'; import { languageSelector } from 'features/system/store/systemSelectors';
import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton'; import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton';
import FloatingParametersPanelButtons from 'features/ui/components/FloatingParametersPanelButtons'; import FloatingParametersPanelButtons from 'features/ui/components/FloatingParametersPanelButtons';
import InvokeTabs from 'features/ui/components/InvokeTabs'; import InvokeTabs from 'features/ui/components/InvokeTabs';
import ParametersDrawer from 'features/ui/components/ParametersDrawer'; import ParametersDrawer from 'features/ui/components/ParametersDrawer';
import { AnimatePresence, motion } from 'framer-motion';
import i18n from 'i18n'; import i18n from 'i18n';
import { ReactNode, memo, useCallback, useEffect, useState } from 'react'; import { ReactNode, memo, useEffect } from 'react';
import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants'; import DeleteBoardImagesModal from '../../features/gallery/components/Boards/DeleteBoardImagesModal';
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
import GlobalHotkeys from './GlobalHotkeys'; import GlobalHotkeys from './GlobalHotkeys';
import Toaster from './Toaster'; import Toaster from './Toaster';
import DeleteImageModal from 'features/gallery/components/DeleteImageModal';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
import { useListModelsQuery } from 'services/api/endpoints/models';
import DeleteBoardImagesModal from '../../features/gallery/components/Boards/DeleteBoardImagesModal';
const DEFAULT_CONFIG = {}; const DEFAULT_CONFIG = {};
interface Props { interface Props {
config?: PartialAppConfig; config?: PartialAppConfig;
headerComponent?: ReactNode; headerComponent?: ReactNode;
setIsReady?: (isReady: boolean) => void;
} }
const App = ({ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
config = DEFAULT_CONFIG,
headerComponent,
setIsReady,
}: Props) => {
const language = useAppSelector(languageSelector); const language = useAppSelector(languageSelector);
const log = useLogger(); const log = useLogger();
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled; const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
const isApplicationReady = useIsApplicationReady();
const { data: pipelineModels } = useListModelsQuery({
model_type: 'main',
});
const { data: controlnetModels } = useListModelsQuery({
model_type: 'controlnet',
});
const { data: vaeModels } = useListModelsQuery({ model_type: 'vae' });
const { data: loraModels } = useListModelsQuery({ model_type: 'lora' });
const { data: embeddingModels } = useListModelsQuery({
model_type: 'embedding',
});
const [loadingOverridden, setLoadingOverridden] = useState(false);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
useEffect(() => { useEffect(() => {
@ -73,27 +46,6 @@ const App = ({
dispatch(configChanged(config)); dispatch(configChanged(config));
}, [dispatch, config, log]); }, [dispatch, config, log]);
const handleOverrideClicked = useCallback(() => {
setLoadingOverridden(true);
}, []);
useEffect(() => {
if (isApplicationReady && setIsReady) {
setIsReady(true);
}
if (isApplicationReady) {
// TODO: This is a jank fix for canvas not filling the screen on first load
setTimeout(() => {
dispatch(requestCanvasRescale());
}, 200);
}
return () => {
setIsReady && setIsReady(false);
};
}, [dispatch, isApplicationReady, setIsReady]);
return ( return (
<> <>
<Grid w="100vw" h="100vh" position="relative" overflow="hidden"> <Grid w="100vw" h="100vh" position="relative" overflow="hidden">
@ -123,33 +75,6 @@ const App = ({
<GalleryDrawer /> <GalleryDrawer />
<ParametersDrawer /> <ParametersDrawer />
<AnimatePresence>
{!isApplicationReady && !loadingOverridden && (
<motion.div
key="loading"
initial={{ opacity: 1 }}
animate={{ opacity: 1 }}
exit={{ opacity: 0 }}
transition={{ duration: 0.3 }}
style={{ zIndex: 3 }}
>
<Box position="absolute" top={0} left={0} w="100vw" h="100vh">
<Loading />
</Box>
<Box
onClick={handleOverrideClicked}
position="absolute"
top={0}
right={0}
cursor="pointer"
w="2rem"
h="2rem"
/>
</motion.div>
)}
</AnimatePresence>
<Portal> <Portal>
<FloatingParametersPanelButtons /> <FloatingParametersPanelButtons />
</Portal> </Portal>

View File

@ -0,0 +1,122 @@
import { Box, ChakraProps, Flex, Heading, Image } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { memo } from 'react';
import { TypesafeDraggableData } from './typesafeDnd';
type OverlayDragImageProps = {
dragData: TypesafeDraggableData | null;
};
const BOX_SIZE = 28;
const STYLES: ChakraProps['sx'] = {
w: BOX_SIZE,
h: BOX_SIZE,
maxW: BOX_SIZE,
maxH: BOX_SIZE,
shadow: 'dark-lg',
borderRadius: 'lg',
borderWidth: 2,
borderStyle: 'dashed',
borderColor: 'base.100',
opacity: 0.5,
bg: 'base.800',
color: 'base.50',
_dark: {
borderColor: 'base.200',
bg: 'base.900',
color: 'base.100',
},
};
const selector = createSelector(
stateSelector,
(state) => {
const gallerySelectionCount = state.gallery.selection.length;
const batchSelectionCount = state.batch.selection.length;
return {
gallerySelectionCount,
batchSelectionCount,
};
},
defaultSelectorOptions
);
const DragPreview = (props: OverlayDragImageProps) => {
const { gallerySelectionCount, batchSelectionCount } =
useAppSelector(selector);
if (!props.dragData) {
return;
}
if (props.dragData.payloadType === 'IMAGE_DTO') {
return (
<Box
sx={{
position: 'relative',
width: '100%',
height: '100%',
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
userSelect: 'none',
cursor: 'none',
}}
>
<Image
sx={{
...STYLES,
}}
src={props.dragData.payload.imageDTO.thumbnail_url}
/>
</Box>
);
}
if (props.dragData.payloadType === 'BATCH_SELECTION') {
return (
<Flex
sx={{
cursor: 'none',
userSelect: 'none',
position: 'relative',
alignItems: 'center',
justifyContent: 'center',
flexDir: 'column',
...STYLES,
}}
>
<Heading>{batchSelectionCount}</Heading>
<Heading size="sm">Images</Heading>
</Flex>
);
}
if (props.dragData.payloadType === 'GALLERY_SELECTION') {
return (
<Flex
sx={{
cursor: 'none',
userSelect: 'none',
position: 'relative',
alignItems: 'center',
justifyContent: 'center',
flexDir: 'column',
...STYLES,
}}
>
<Heading>{gallerySelectionCount}</Heading>
<Heading size="sm">Images</Heading>
</Flex>
);
}
return null;
};
export default memo(DragPreview);

View File

@ -1,8 +1,5 @@
import { import {
DndContext,
DragEndEvent,
DragOverlay, DragOverlay,
DragStartEvent,
MouseSensor, MouseSensor,
TouchSensor, TouchSensor,
pointerWithin, pointerWithin,
@ -10,33 +7,45 @@ import {
useSensors, useSensors,
} from '@dnd-kit/core'; } from '@dnd-kit/core';
import { PropsWithChildren, memo, useCallback, useState } from 'react'; import { PropsWithChildren, memo, useCallback, useState } from 'react';
import OverlayDragImage from './OverlayDragImage'; import DragPreview from './DragPreview';
import { ImageDTO } from 'services/api/types';
import { isImageDTO } from 'services/api/guards';
import { snapCenterToCursor } from '@dnd-kit/modifiers'; import { snapCenterToCursor } from '@dnd-kit/modifiers';
import { AnimatePresence, motion } from 'framer-motion'; import { AnimatePresence, motion } from 'framer-motion';
import {
DndContext,
DragEndEvent,
DragStartEvent,
TypesafeDraggableData,
} from './typesafeDnd';
import { useAppDispatch } from 'app/store/storeHooks';
import { imageDropped } from 'app/store/middleware/listenerMiddleware/listeners/imageDropped';
type ImageDndContextProps = PropsWithChildren; type ImageDndContextProps = PropsWithChildren;
const ImageDndContext = (props: ImageDndContextProps) => { const ImageDndContext = (props: ImageDndContextProps) => {
const [draggedImage, setDraggedImage] = useState<ImageDTO | null>(null); const [activeDragData, setActiveDragData] =
useState<TypesafeDraggableData | null>(null);
const dispatch = useAppDispatch();
const handleDragStart = useCallback((event: DragStartEvent) => { const handleDragStart = useCallback((event: DragStartEvent) => {
const dragData = event.active.data.current; const activeData = event.active.data.current;
if (dragData && 'image' in dragData && isImageDTO(dragData.image)) { if (!activeData) {
setDraggedImage(dragData.image); return;
} }
setActiveDragData(activeData);
}, []); }, []);
const handleDragEnd = useCallback( const handleDragEnd = useCallback(
(event: DragEndEvent) => { (event: DragEndEvent) => {
const handleDrop = event.over?.data.current?.handleDrop; const activeData = event.active.data.current;
if (handleDrop && typeof handleDrop === 'function' && draggedImage) { const overData = event.over?.data.current;
handleDrop(draggedImage); if (!activeData || !overData) {
return;
} }
setDraggedImage(null); dispatch(imageDropped({ overData, activeData }));
setActiveDragData(null);
}, },
[draggedImage] [dispatch]
); );
const mouseSensor = useSensor(MouseSensor, { const mouseSensor = useSensor(MouseSensor, {
@ -46,6 +55,7 @@ const ImageDndContext = (props: ImageDndContextProps) => {
const touchSensor = useSensor(TouchSensor, { const touchSensor = useSensor(TouchSensor, {
activationConstraint: { delay: 150, tolerance: 5 }, activationConstraint: { delay: 150, tolerance: 5 },
}); });
// TODO: Use KeyboardSensor - needs composition of multiple collisionDetection algos // TODO: Use KeyboardSensor - needs composition of multiple collisionDetection algos
// Alternatively, fix `rectIntersection` collection detection to work with the drag overlay // Alternatively, fix `rectIntersection` collection detection to work with the drag overlay
// (currently the drag element collision rect is not correctly calculated) // (currently the drag element collision rect is not correctly calculated)
@ -63,7 +73,7 @@ const ImageDndContext = (props: ImageDndContextProps) => {
{props.children} {props.children}
<DragOverlay dropAnimation={null} modifiers={[snapCenterToCursor]}> <DragOverlay dropAnimation={null} modifiers={[snapCenterToCursor]}>
<AnimatePresence> <AnimatePresence>
{draggedImage && ( {activeDragData && (
<motion.div <motion.div
layout layout
key="overlay-drag-image" key="overlay-drag-image"
@ -77,7 +87,7 @@ const ImageDndContext = (props: ImageDndContextProps) => {
transition: { duration: 0.1 }, transition: { duration: 0.1 },
}} }}
> >
<OverlayDragImage image={draggedImage} /> <DragPreview dragData={activeDragData} />
</motion.div> </motion.div>
)} )}
</AnimatePresence> </AnimatePresence>

View File

@ -1,36 +0,0 @@
import { Box, Image } from '@chakra-ui/react';
import { memo } from 'react';
import { ImageDTO } from 'services/api/types';
type OverlayDragImageProps = {
image: ImageDTO;
};
const OverlayDragImage = (props: OverlayDragImageProps) => {
return (
<Box
style={{
width: '100%',
height: '100%',
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
userSelect: 'none',
cursor: 'grabbing',
opacity: 0.5,
}}
>
<Image
sx={{
maxW: 36,
maxH: 36,
borderRadius: 'base',
shadow: 'dark-lg',
}}
src={props.image.thumbnail_url}
/>
</Box>
);
};
export default memo(OverlayDragImage);

View File

@ -0,0 +1,201 @@
// type-safe dnd from https://github.com/clauderic/dnd-kit/issues/935
import {
Active,
Collision,
DndContextProps,
DndContext as OriginalDndContext,
Over,
Translate,
UseDraggableArguments,
UseDroppableArguments,
useDraggable as useOriginalDraggable,
useDroppable as useOriginalDroppable,
} from '@dnd-kit/core';
import { ImageDTO } from 'services/api/types';
type BaseDropData = {
id: string;
};
export type CurrentImageDropData = BaseDropData & {
actionType: 'SET_CURRENT_IMAGE';
};
export type InitialImageDropData = BaseDropData & {
actionType: 'SET_INITIAL_IMAGE';
};
export type ControlNetDropData = BaseDropData & {
actionType: 'SET_CONTROLNET_IMAGE';
context: {
controlNetId: string;
};
};
export type CanvasInitialImageDropData = BaseDropData & {
actionType: 'SET_CANVAS_INITIAL_IMAGE';
};
export type NodesImageDropData = BaseDropData & {
actionType: 'SET_NODES_IMAGE';
context: {
nodeId: string;
fieldName: string;
};
};
export type NodesMultiImageDropData = BaseDropData & {
actionType: 'SET_MULTI_NODES_IMAGE';
context: { nodeId: string; fieldName: string };
};
export type AddToBatchDropData = BaseDropData & {
actionType: 'ADD_TO_BATCH';
};
export type MoveBoardDropData = BaseDropData & {
actionType: 'MOVE_BOARD';
context: { boardId: string | null };
};
export type TypesafeDroppableData =
| CurrentImageDropData
| InitialImageDropData
| ControlNetDropData
| CanvasInitialImageDropData
| NodesImageDropData
| AddToBatchDropData
| NodesMultiImageDropData
| MoveBoardDropData;
type BaseDragData = {
id: string;
};
export type ImageDraggableData = BaseDragData & {
payloadType: 'IMAGE_DTO';
payload: { imageDTO: ImageDTO };
};
export type GallerySelectionDraggableData = BaseDragData & {
payloadType: 'GALLERY_SELECTION';
};
export type BatchSelectionDraggableData = BaseDragData & {
payloadType: 'BATCH_SELECTION';
};
export type TypesafeDraggableData =
| ImageDraggableData
| GallerySelectionDraggableData
| BatchSelectionDraggableData;
interface UseDroppableTypesafeArguments
extends Omit<UseDroppableArguments, 'data'> {
data?: TypesafeDroppableData;
}
type UseDroppableTypesafeReturnValue = Omit<
ReturnType<typeof useOriginalDroppable>,
'active' | 'over'
> & {
active: TypesafeActive | null;
over: TypesafeOver | null;
};
export function useDroppable(props: UseDroppableTypesafeArguments) {
return useOriginalDroppable(props) as UseDroppableTypesafeReturnValue;
}
interface UseDraggableTypesafeArguments
extends Omit<UseDraggableArguments, 'data'> {
data?: TypesafeDraggableData;
}
type UseDraggableTypesafeReturnValue = Omit<
ReturnType<typeof useOriginalDraggable>,
'active' | 'over'
> & {
active: TypesafeActive | null;
over: TypesafeOver | null;
};
export function useDraggable(props: UseDraggableTypesafeArguments) {
return useOriginalDraggable(props) as UseDraggableTypesafeReturnValue;
}
interface TypesafeActive extends Omit<Active, 'data'> {
data: React.MutableRefObject<TypesafeDraggableData | undefined>;
}
interface TypesafeOver extends Omit<Over, 'data'> {
data: React.MutableRefObject<TypesafeDroppableData | undefined>;
}
export const isValidDrop = (
overData: TypesafeDroppableData | undefined,
active: TypesafeActive | null
) => {
if (!overData || !active?.data.current) {
return false;
}
const { actionType } = overData;
const { payloadType } = active.data.current;
if (overData.id === active.data.current.id) {
return false;
}
switch (actionType) {
case 'SET_CURRENT_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_INITIAL_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_CONTROLNET_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_CANVAS_INITIAL_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_NODES_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_MULTI_NODES_IMAGE':
return payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION';
case 'ADD_TO_BATCH':
return payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION';
case 'MOVE_BOARD':
return (
payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION' || 'BATCH_SELECTION'
);
default:
return false;
}
};
interface DragEvent {
activatorEvent: Event;
active: TypesafeActive;
collisions: Collision[] | null;
delta: Translate;
over: TypesafeOver | null;
}
export interface DragStartEvent extends Pick<DragEvent, 'active'> {}
export interface DragMoveEvent extends DragEvent {}
export interface DragOverEvent extends DragMoveEvent {}
export interface DragEndEvent extends DragEvent {}
export interface DragCancelEvent extends DragEndEvent {}
export interface DndContextTypesafeProps
extends Omit<
DndContextProps,
'onDragStart' | 'onDragMove' | 'onDragOver' | 'onDragEnd' | 'onDragCancel'
> {
onDragStart?(event: DragStartEvent): void;
onDragMove?(event: DragMoveEvent): void;
onDragOver?(event: DragOverEvent): void;
onDragEnd?(event: DragEndEvent): void;
onDragCancel?(event: DragCancelEvent): void;
}
export function DndContext(props: DndContextTypesafeProps) {
return <OriginalDndContext {...props} />;
}

View File

@ -7,7 +7,6 @@ import React, {
} from 'react'; } from 'react';
import { Provider } from 'react-redux'; import { Provider } from 'react-redux';
import { store } from 'app/store/store'; import { store } from 'app/store/store';
// import { OpenAPI } from 'services/api/types';
import Loading from '../../common/components/Loading/Loading'; import Loading from '../../common/components/Loading/Loading';
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares'; import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
@ -17,11 +16,6 @@ import '../../i18n';
import { socketMiddleware } from 'services/events/middleware'; import { socketMiddleware } from 'services/events/middleware';
import { Middleware } from '@reduxjs/toolkit'; import { Middleware } from '@reduxjs/toolkit';
import ImageDndContext from './ImageDnd/ImageDndContext'; import ImageDndContext from './ImageDnd/ImageDndContext';
import {
DeleteImageContext,
DeleteImageContextProvider,
} from 'app/contexts/DeleteImageContext';
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
import { AddImageToBoardContextProvider } from '../contexts/AddImageToBoardContext'; import { AddImageToBoardContextProvider } from '../contexts/AddImageToBoardContext';
import { $authToken, $baseUrl } from 'services/api/client'; import { $authToken, $baseUrl } from 'services/api/client';
import { DeleteBoardImagesContextProvider } from '../contexts/DeleteBoardImagesContext'; import { DeleteBoardImagesContextProvider } from '../contexts/DeleteBoardImagesContext';
@ -34,7 +28,6 @@ interface Props extends PropsWithChildren {
token?: string; token?: string;
config?: PartialAppConfig; config?: PartialAppConfig;
headerComponent?: ReactNode; headerComponent?: ReactNode;
setIsReady?: (isReady: boolean) => void;
middleware?: Middleware[]; middleware?: Middleware[];
} }
@ -43,7 +36,6 @@ const InvokeAIUI = ({
token, token,
config, config,
headerComponent, headerComponent,
setIsReady,
middleware, middleware,
}: Props) => { }: Props) => {
useEffect(() => { useEffect(() => {
@ -85,17 +77,11 @@ const InvokeAIUI = ({
<React.Suspense fallback={<Loading />}> <React.Suspense fallback={<Loading />}>
<ThemeLocaleProvider> <ThemeLocaleProvider>
<ImageDndContext> <ImageDndContext>
<DeleteImageContextProvider>
<AddImageToBoardContextProvider> <AddImageToBoardContextProvider>
<DeleteBoardImagesContextProvider> <DeleteBoardImagesContextProvider>
<App <App config={config} headerComponent={headerComponent} />
config={config}
headerComponent={headerComponent}
setIsReady={setIsReady}
/>
</DeleteBoardImagesContextProvider> </DeleteBoardImagesContextProvider>
</AddImageToBoardContextProvider> </AddImageToBoardContextProvider>
</DeleteImageContextProvider>
</ImageDndContext> </ImageDndContext>
</ThemeLocaleProvider> </ThemeLocaleProvider>
</React.Suspense> </React.Suspense>

View File

@ -5,15 +5,15 @@ import { useDeleteBoardMutation } from '../../services/api/endpoints/boards';
import { defaultSelectorOptions } from '../store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from '../store/util/defaultMemoizeOptions';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { some } from 'lodash-es'; import { some } from 'lodash-es';
import { canvasSelector } from '../../features/canvas/store/canvasSelectors'; import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { controlNetSelector } from '../../features/controlNet/store/controlNetSlice'; import { controlNetSelector } from 'features/controlNet/store/controlNetSlice';
import { selectImagesById } from '../../features/gallery/store/imagesSlice'; import { selectImagesById } from 'features/gallery/store/gallerySlice';
import { nodesSelector } from '../../features/nodes/store/nodesSlice'; import { nodesSelector } from 'features/nodes/store/nodesSlice';
import { generationSelector } from '../../features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import { RootState } from '../store/store'; import { RootState } from '../store/store';
import { useAppDispatch, useAppSelector } from '../store/storeHooks'; import { useAppDispatch, useAppSelector } from '../store/storeHooks';
import { ImageUsage } from './DeleteImageContext'; import { ImageUsage } from './DeleteImageContext';
import { requestedBoardImagesDeletion } from '../../features/gallery/store/actions'; import { requestedBoardImagesDeletion } from 'features/gallery/store/actions';
export const selectBoardImagesUsage = createSelector( export const selectBoardImagesUsage = createSelector(
[ [

View File

@ -1,201 +0,0 @@
import { useDisclosure } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { requestedImageDeletion } from 'features/gallery/store/actions';
import { systemSelector } from 'features/system/store/systemSelectors';
import {
PropsWithChildren,
createContext,
useCallback,
useEffect,
useState,
} from 'react';
import { ImageDTO } from 'services/api/types';
import { RootState } from 'app/store/store';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { controlNetSelector } from 'features/controlNet/store/controlNetSlice';
import { nodesSelector } from 'features/nodes/store/nodesSlice';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { some } from 'lodash-es';
export type ImageUsage = {
isInitialImage: boolean;
isCanvasImage: boolean;
isNodesImage: boolean;
isControlNetImage: boolean;
};
export const selectImageUsage = createSelector(
[
generationSelector,
canvasSelector,
nodesSelector,
controlNetSelector,
(state: RootState, image_name?: string) => image_name,
],
(generation, canvas, nodes, controlNet, image_name) => {
const isInitialImage = generation.initialImage?.imageName === image_name;
const isCanvasImage = canvas.layerState.objects.some(
(obj) => obj.kind === 'image' && obj.imageName === image_name
);
const isNodesImage = nodes.nodes.some((node) => {
return some(
node.data.inputs,
(input) => input.type === 'image' && input.value === image_name
);
});
const isControlNetImage = some(
controlNet.controlNets,
(c) =>
c.controlImage === image_name || c.processedControlImage === image_name
);
const imageUsage: ImageUsage = {
isInitialImage,
isCanvasImage,
isNodesImage,
isControlNetImage,
};
return imageUsage;
},
defaultSelectorOptions
);
type DeleteImageContextValue = {
/**
* Whether the delete image dialog is open.
*/
isOpen: boolean;
/**
* Closes the delete image dialog.
*/
onClose: () => void;
/**
* Opens the delete image dialog and handles all deletion-related checks.
*/
onDelete: (image?: ImageDTO) => void;
/**
* The image pending deletion
*/
image?: ImageDTO;
/**
* The features in which this image is used
*/
imageUsage?: ImageUsage;
/**
* Immediately deletes an image.
*
* You probably don't want to use this - use `onDelete` instead.
*/
onImmediatelyDelete: () => void;
};
export const DeleteImageContext = createContext<DeleteImageContextValue>({
isOpen: false,
onClose: () => undefined,
onImmediatelyDelete: () => undefined,
onDelete: () => undefined,
});
const selector = createSelector(
[systemSelector],
(system) => {
const { isProcessing, isConnected, shouldConfirmOnDelete } = system;
return {
canDeleteImage: isConnected && !isProcessing,
shouldConfirmOnDelete,
};
},
defaultSelectorOptions
);
type Props = PropsWithChildren;
export const DeleteImageContextProvider = (props: Props) => {
const { canDeleteImage, shouldConfirmOnDelete } = useAppSelector(selector);
const [imageToDelete, setImageToDelete] = useState<ImageDTO>();
const dispatch = useAppDispatch();
const { isOpen, onOpen, onClose } = useDisclosure();
// Check where the image to be deleted is used (eg init image, controlnet, etc.)
const imageUsage = useAppSelector((state) =>
selectImageUsage(state, imageToDelete?.image_name)
);
// Clean up after deleting or dismissing the modal
const closeAndClearImageToDelete = useCallback(() => {
setImageToDelete(undefined);
onClose();
}, [onClose]);
// Dispatch the actual deletion action, to be handled by listener middleware
const handleActualDeletion = useCallback(
(image: ImageDTO) => {
dispatch(requestedImageDeletion({ image, imageUsage }));
closeAndClearImageToDelete();
},
[closeAndClearImageToDelete, dispatch, imageUsage]
);
// This is intended to be called by the delete button in the dialog
const onImmediatelyDelete = useCallback(() => {
if (canDeleteImage && imageToDelete) {
handleActualDeletion(imageToDelete);
}
closeAndClearImageToDelete();
}, [
canDeleteImage,
imageToDelete,
closeAndClearImageToDelete,
handleActualDeletion,
]);
const handleGatedDeletion = useCallback(
(image: ImageDTO) => {
if (shouldConfirmOnDelete || some(imageUsage)) {
// If we should confirm on delete, or if the image is in use, open the dialog
onOpen();
} else {
handleActualDeletion(image);
}
},
[imageUsage, shouldConfirmOnDelete, onOpen, handleActualDeletion]
);
// Consumers of the context call this to delete an image
const onDelete = useCallback((image?: ImageDTO) => {
if (!image) {
return;
}
// Set the image to delete, then let the effect call the actual deletion
setImageToDelete(image);
}, []);
useEffect(() => {
// We need to use an effect here to trigger the image usage selector, else we get a stale value
if (imageToDelete) {
handleGatedDeletion(imageToDelete);
}
}, [handleGatedDeletion, imageToDelete]);
return (
<DeleteImageContext.Provider
value={{
isOpen,
image: imageToDelete,
onClose: closeAndClearImageToDelete,
onDelete,
onImmediatelyDelete,
imageUsage,
}}
>
{props.children}
</DeleteImageContext.Provider>
);
};

View File

@ -20,10 +20,8 @@ const serializationDenylist: {
nodes: nodesPersistDenylist, nodes: nodesPersistDenylist,
postprocessing: postprocessingPersistDenylist, postprocessing: postprocessingPersistDenylist,
system: systemPersistDenylist, system: systemPersistDenylist,
// config: configPersistDenyList,
ui: uiPersistDenylist, ui: uiPersistDenylist,
controlNet: controlNetDenylist, controlNet: controlNetDenylist,
// hotkeys: hotkeysPersistDenylist,
}; };
export const serialize: SerializeFunction = (data, key) => { export const serialize: SerializeFunction = (data, key) => {

View File

@ -1,7 +1,6 @@
import { initialCanvasState } from 'features/canvas/store/canvasSlice'; import { initialCanvasState } from 'features/canvas/store/canvasSlice';
import { initialControlNetState } from 'features/controlNet/store/controlNetSlice'; import { initialControlNetState } from 'features/controlNet/store/controlNetSlice';
import { initialGalleryState } from 'features/gallery/store/gallerySlice'; import { initialGalleryState } from 'features/gallery/store/gallerySlice';
import { initialImagesState } from 'features/gallery/store/imagesSlice';
import { initialLightboxState } from 'features/lightbox/store/lightboxSlice'; import { initialLightboxState } from 'features/lightbox/store/lightboxSlice';
import { initialNodesState } from 'features/nodes/store/nodesSlice'; import { initialNodesState } from 'features/nodes/store/nodesSlice';
import { initialGenerationState } from 'features/parameters/store/generationSlice'; import { initialGenerationState } from 'features/parameters/store/generationSlice';
@ -26,7 +25,6 @@ const initialStates: {
config: initialConfigState, config: initialConfigState,
ui: initialUIState, ui: initialUIState,
hotkeys: initialHotkeysState, hotkeys: initialHotkeysState,
images: initialImagesState,
controlNet: initialControlNetState, controlNet: initialControlNetState,
}; };

View File

@ -72,7 +72,6 @@ import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingA
import { addImageCategoriesChangedListener } from './listeners/imageCategoriesChanged'; import { addImageCategoriesChangedListener } from './listeners/imageCategoriesChanged';
import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed'; import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed';
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess'; import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
import { addUpdateImageUrlsOnConnectListener } from './listeners/updateImageUrlsOnConnect';
import { import {
addImageAddedToBoardFulfilledListener, addImageAddedToBoardFulfilledListener,
addImageAddedToBoardRejectedListener, addImageAddedToBoardRejectedListener,
@ -84,6 +83,9 @@ import {
} from './listeners/imageRemovedFromBoard'; } from './listeners/imageRemovedFromBoard';
import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema'; import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema';
import { addRequestedBoardImageDeletionListener } from './listeners/boardImagesDeleted'; import { addRequestedBoardImageDeletionListener } from './listeners/boardImagesDeleted';
import { addSelectionAddedToBatchListener } from './listeners/selectionAddedToBatch';
import { addImageDroppedListener } from './listeners/imageDropped';
import { addImageToDeleteSelectedListener } from './listeners/imageToDeleteSelected';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
@ -126,6 +128,7 @@ addImageDeletedPendingListener();
addImageDeletedFulfilledListener(); addImageDeletedFulfilledListener();
addImageDeletedRejectedListener(); addImageDeletedRejectedListener();
addRequestedBoardImageDeletionListener(); addRequestedBoardImageDeletionListener();
addImageToDeleteSelectedListener();
// Image metadata // Image metadata
addImageMetadataReceivedFulfilledListener(); addImageMetadataReceivedFulfilledListener();
@ -211,3 +214,9 @@ addBoardIdSelectedListener();
// Node schemas // Node schemas
addReceivedOpenAPISchemaListener(); addReceivedOpenAPISchemaListener();
// Batches
addSelectionAddedToBatchListener();
// DND
addImageDroppedListener();

View File

@ -1,12 +1,14 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { boardIdSelected } from 'features/gallery/store/boardSlice'; import {
import { selectImagesAll } from 'features/gallery/store/imagesSlice'; imageSelected,
selectImagesAll,
boardIdSelected,
} from 'features/gallery/store/gallerySlice';
import { import {
IMAGES_PER_PAGE, IMAGES_PER_PAGE,
receivedPageOfImages, receivedPageOfImages,
} from 'services/api/thunks/image'; } from 'services/api/thunks/image';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { boardsApi } from 'services/api/endpoints/boards'; import { boardsApi } from 'services/api/endpoints/boards';
const moduleLog = log.child({ namespace: 'boards' }); const moduleLog = log.child({ namespace: 'boards' });
@ -28,7 +30,7 @@ export const addBoardIdSelectedListener = () => {
return; return;
} }
const { categories } = state.images; const { categories } = state.gallery;
const filteredImages = allImages.filter((i) => { const filteredImages = allImages.filter((i) => {
const isInCategory = categories.includes(i.image_category); const isInCategory = categories.includes(i.image_category);
@ -47,7 +49,7 @@ export const addBoardIdSelectedListener = () => {
return; return;
} }
dispatch(imageSelected(board.cover_image_name)); dispatch(imageSelected(board.cover_image_name ?? null));
// if we haven't loaded one full page of images from this board, load more // if we haven't loaded one full page of images from this board, load more
if ( if (
@ -77,7 +79,7 @@ export const addBoardIdSelected_changeSelectedImage_listener = () => {
return; return;
} }
const { categories } = state.images; const { categories } = state.gallery;
const filteredImages = selectImagesAll(state).filter((i) => { const filteredImages = selectImagesAll(state).filter((i) => {
const isInCategory = categories.includes(i.image_category); const isInCategory = categories.includes(i.image_category);

View File

@ -1,11 +1,11 @@
import { requestedBoardImagesDeletion } from 'features/gallery/store/actions'; import { requestedBoardImagesDeletion } from 'features/gallery/store/actions';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { import {
imageSelected,
imagesRemoved, imagesRemoved,
selectImagesAll, selectImagesAll,
selectImagesById, selectImagesById,
} from 'features/gallery/store/imagesSlice'; } from 'features/gallery/store/gallerySlice';
import { resetCanvas } from 'features/canvas/store/canvasSlice'; import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { controlNetReset } from 'features/controlNet/store/controlNetSlice'; import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice'; import { clearInitialImage } from 'features/parameters/store/generationSlice';
@ -22,12 +22,15 @@ export const addRequestedBoardImageDeletionListener = () => {
const { board_id } = board; const { board_id } = board;
const state = getState(); const state = getState();
const selectedImage = state.gallery.selectedImage const selectedImageName =
? selectImagesById(state, state.gallery.selectedImage) state.gallery.selection[state.gallery.selection.length - 1];
const selectedImage = selectedImageName
? selectImagesById(state, selectedImageName)
: undefined; : undefined;
if (selectedImage && selectedImage.board_id === board_id) { if (selectedImage && selectedImage.board_id === board_id) {
dispatch(imageSelected()); dispatch(imageSelected(null));
} }
// We need to reset the features where the board images are in use - none of these work if their image(s) don't exist // We need to reset the features where the board images are in use - none of these work if their image(s) don't exist

View File

@ -4,7 +4,7 @@ import { log } from 'app/logging/useLogger';
import { imageUploaded } from 'services/api/thunks/image'; import { imageUploaded } from 'services/api/thunks/image';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { imageUpserted } from 'features/gallery/store/imagesSlice'; import { imageUpserted } from 'features/gallery/store/gallerySlice';
const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' }); const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' });

View File

@ -3,8 +3,8 @@ import { startAppListening } from '..';
import { receivedPageOfImages } from 'services/api/thunks/image'; import { receivedPageOfImages } from 'services/api/thunks/image';
import { import {
imageCategoriesChanged, imageCategoriesChanged,
selectFilteredImagesAsArray, selectFilteredImages,
} from 'features/gallery/store/imagesSlice'; } from 'features/gallery/store/gallerySlice';
const moduleLog = log.child({ namespace: 'gallery' }); const moduleLog = log.child({ namespace: 'gallery' });
@ -13,7 +13,7 @@ export const addImageCategoriesChangedListener = () => {
actionCreator: imageCategoriesChanged, actionCreator: imageCategoriesChanged,
effect: (action, { getState, dispatch }) => { effect: (action, { getState, dispatch }) => {
const state = getState(); const state = getState();
const filteredImagesCount = selectFilteredImagesAsArray(state).length; const filteredImagesCount = selectFilteredImages(state).length;
if (!filteredImagesCount) { if (!filteredImagesCount) {
dispatch( dispatch(

View File

@ -1,18 +1,21 @@
import { requestedImageDeletion } from 'features/gallery/store/actions';
import { startAppListening } from '..';
import { imageDeleted } from 'services/api/thunks/image';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { clamp } from 'lodash-es';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import {
imageRemoved,
selectImagesIds,
} from 'features/gallery/store/imagesSlice';
import { resetCanvas } from 'features/canvas/store/canvasSlice'; import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { controlNetReset } from 'features/controlNet/store/controlNetSlice'; import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice'; import {
imageRemoved,
imageSelected,
selectFilteredImages,
} from 'features/gallery/store/gallerySlice';
import {
imageDeletionConfirmed,
isModalOpenChanged,
} from 'features/imageDeletion/store/imageDeletionSlice';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice'; import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { clamp } from 'lodash-es';
import { api } from 'services/api'; import { api } from 'services/api';
import { imageDeleted } from 'services/api/thunks/image';
import { startAppListening } from '..';
const moduleLog = log.child({ namespace: 'image' }); const moduleLog = log.child({ namespace: 'image' });
@ -21,17 +24,22 @@ const moduleLog = log.child({ namespace: 'image' });
*/ */
export const addRequestedImageDeletionListener = () => { export const addRequestedImageDeletionListener = () => {
startAppListening({ startAppListening({
actionCreator: requestedImageDeletion, actionCreator: imageDeletionConfirmed,
effect: async (action, { dispatch, getState, condition }) => { effect: async (action, { dispatch, getState, condition }) => {
const { image, imageUsage } = action.payload; const { imageDTO, imageUsage } = action.payload;
const { image_name } = image; dispatch(isModalOpenChanged(false));
const { image_name } = imageDTO;
const state = getState(); const state = getState();
const selectedImage = state.gallery.selectedImage; const lastSelectedImage =
state.gallery.selection[state.gallery.selection.length - 1];
if (selectedImage === image_name) { if (lastSelectedImage === image_name) {
const ids = selectImagesIds(state); const filteredImages = selectFilteredImages(state);
const ids = filteredImages.map((i) => i.image_name);
const deletedImageIndex = ids.findIndex( const deletedImageIndex = ids.findIndex(
(result) => result.toString() === image_name (result) => result.toString() === image_name
@ -50,7 +58,7 @@ export const addRequestedImageDeletionListener = () => {
if (newSelectedImageId) { if (newSelectedImageId) {
dispatch(imageSelected(newSelectedImageId as string)); dispatch(imageSelected(newSelectedImageId as string));
} else { } else {
dispatch(imageSelected()); dispatch(imageSelected(null));
} }
} }
@ -88,7 +96,7 @@ export const addRequestedImageDeletionListener = () => {
if (wasImageDeleted) { if (wasImageDeleted) {
dispatch( dispatch(
api.util.invalidateTags([{ type: 'Board', id: image.board_id }]) api.util.invalidateTags([{ type: 'Board', id: imageDTO.board_id }])
); );
} }
}, },

View File

@ -0,0 +1,188 @@
import { createAction } from '@reduxjs/toolkit';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd';
import { log } from 'app/logging/useLogger';
import {
imageAddedToBatch,
imagesAddedToBatch,
} from 'features/batch/store/batchSlice';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import {
fieldValueChanged,
imageCollectionFieldValueChanged,
} from 'features/nodes/store/nodesSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { boardImagesApi } from 'services/api/endpoints/boardImages';
import { startAppListening } from '../';
const moduleLog = log.child({ namespace: 'dnd' });
export const imageDropped = createAction<{
overData: TypesafeDroppableData;
activeData: TypesafeDraggableData;
}>('dnd/imageDropped');
export const addImageDroppedListener = () => {
startAppListening({
actionCreator: imageDropped,
effect: (action, { dispatch, getState }) => {
const { activeData, overData } = action.payload;
const { actionType } = overData;
const state = getState();
// set current image
if (
actionType === 'SET_CURRENT_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
dispatch(imageSelected(activeData.payload.imageDTO.image_name));
}
// set initial image
if (
actionType === 'SET_INITIAL_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
dispatch(initialImageChanged(activeData.payload.imageDTO));
}
// add image to batch
if (
actionType === 'ADD_TO_BATCH' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
dispatch(imageAddedToBatch(activeData.payload.imageDTO.image_name));
}
// add multiple images to batch
if (
actionType === 'ADD_TO_BATCH' &&
activeData.payloadType === 'GALLERY_SELECTION'
) {
dispatch(imagesAddedToBatch(state.gallery.selection));
}
// set control image
if (
actionType === 'SET_CONTROLNET_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { controlNetId } = overData.context;
dispatch(
controlNetImageChanged({
controlImage: activeData.payload.imageDTO.image_name,
controlNetId,
})
);
}
// set canvas image
if (
actionType === 'SET_CANVAS_INITIAL_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
dispatch(setInitialCanvasImage(activeData.payload.imageDTO));
}
// set nodes image
if (
actionType === 'SET_NODES_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { fieldName, nodeId } = overData.context;
dispatch(
fieldValueChanged({
nodeId,
fieldName,
value: activeData.payload.imageDTO,
})
);
}
// set multiple nodes images (single image handler)
if (
actionType === 'SET_MULTI_NODES_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { fieldName, nodeId } = overData.context;
dispatch(
fieldValueChanged({
nodeId,
fieldName,
value: [activeData.payload.imageDTO],
})
);
}
// set multiple nodes images (multiple images handler)
if (
actionType === 'SET_MULTI_NODES_IMAGE' &&
activeData.payloadType === 'GALLERY_SELECTION'
) {
const { fieldName, nodeId } = overData.context;
dispatch(
imageCollectionFieldValueChanged({
nodeId,
fieldName,
value: state.gallery.selection.map((image_name) => ({
image_name,
})),
})
);
}
// remove image from board
// TODO: remove board_id from `removeImageFromBoard()` endpoint
// TODO: handle multiple images
// if (
// actionType === 'MOVE_BOARD' &&
// activeData.payloadType === 'IMAGE_DTO' &&
// activeData.payload.imageDTO &&
// overData.boardId !== null
// ) {
// const { image_name } = activeData.payload.imageDTO;
// dispatch(
// boardImagesApi.endpoints.removeImageFromBoard.initiate({ image_name })
// );
// }
// add image to board
if (
actionType === 'MOVE_BOARD' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO &&
overData.context.boardId
) {
const { image_name } = activeData.payload.imageDTO;
const { boardId } = overData.context;
dispatch(
boardImagesApi.endpoints.addImageToBoard.initiate({
image_name,
board_id: boardId,
})
);
}
// add multiple images to board
// TODO: add endpoint
// if (
// actionType === 'ADD_TO_BATCH' &&
// activeData.payloadType === 'IMAGE_NAMES' &&
// activeData.payload.imageDTONames
// ) {
// dispatch(boardImagesApi.endpoints.addImagesToBoard.intiate({}));
// }
},
});
};

View File

@ -1,7 +1,7 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { imageMetadataReceived, imageUpdated } from 'services/api/thunks/image'; import { imageMetadataReceived, imageUpdated } from 'services/api/thunks/image';
import { imageUpserted } from 'features/gallery/store/imagesSlice'; import { imageUpserted } from 'features/gallery/store/gallerySlice';
const moduleLog = log.child({ namespace: 'image' }); const moduleLog = log.child({ namespace: 'image' });

View File

@ -0,0 +1,40 @@
import { startAppListening } from '..';
import { log } from 'app/logging/useLogger';
import {
imageDeletionConfirmed,
imageToDeleteSelected,
isModalOpenChanged,
selectImageUsage,
} from 'features/imageDeletion/store/imageDeletionSlice';
const moduleLog = log.child({ namespace: 'image' });
export const addImageToDeleteSelectedListener = () => {
startAppListening({
actionCreator: imageToDeleteSelected,
effect: async (action, { dispatch, getState, condition }) => {
const imageDTO = action.payload;
const state = getState();
const { shouldConfirmOnDelete } = state.system;
const imageUsage = selectImageUsage(getState());
if (!imageUsage) {
// should never happen
return;
}
const isImageInUse =
imageUsage.isCanvasImage ||
imageUsage.isInitialImage ||
imageUsage.isControlNetImage ||
imageUsage.isNodesImage;
if (shouldConfirmOnDelete || isImageInUse) {
dispatch(isModalOpenChanged(true));
return;
}
dispatch(imageDeletionConfirmed({ imageDTO, imageUsage }));
},
});
};

View File

@ -2,11 +2,12 @@ import { startAppListening } from '..';
import { imageUploaded } from 'services/api/thunks/image'; import { imageUploaded } from 'services/api/thunks/image';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { imageUpserted } from 'features/gallery/store/imagesSlice'; import { imageUpserted } from 'features/gallery/store/gallerySlice';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice'; import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { imageAddedToBatch } from 'features/batch/store/batchSlice';
const moduleLog = log.child({ namespace: 'image' }); const moduleLog = log.child({ namespace: 'image' });
@ -70,6 +71,11 @@ export const addImageUploadedFulfilledListener = () => {
dispatch(addToast({ title: 'Image Uploaded', status: 'success' })); dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));
return; return;
} }
if (postUploadAction?.type === 'ADD_TO_BATCH') {
dispatch(imageAddedToBatch(image.image_name));
return;
}
}, },
}); });
}; };

View File

@ -1,7 +1,7 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { imageUrlsReceived } from 'services/api/thunks/image'; import { imageUrlsReceived } from 'services/api/thunks/image';
import { imageUpdatedOne } from 'features/gallery/store/imagesSlice'; import { imageUpdatedOne } from 'features/gallery/store/gallerySlice';
const moduleLog = log.child({ namespace: 'image' }); const moduleLog = log.child({ namespace: 'image' });

View File

@ -4,7 +4,7 @@ import { addToast } from 'features/system/store/systemSlice';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { initialImageSelected } from 'features/parameters/store/actions'; import { initialImageSelected } from 'features/parameters/store/actions';
import { makeToast } from 'app/components/Toaster'; import { makeToast } from 'app/components/Toaster';
import { selectImagesById } from 'features/gallery/store/imagesSlice'; import { selectImagesById } from 'features/gallery/store/gallerySlice';
import { isImageDTO } from 'services/api/guards'; import { isImageDTO } from 'services/api/guards';
export const addInitialImageSelectedListener = () => { export const addInitialImageSelectedListener = () => {

View File

@ -2,6 +2,7 @@ import { log } from 'app/logging/useLogger';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { serializeError } from 'serialize-error'; import { serializeError } from 'serialize-error';
import { receivedPageOfImages } from 'services/api/thunks/image'; import { receivedPageOfImages } from 'services/api/thunks/image';
import { imagesApi } from 'services/api/endpoints/images';
const moduleLog = log.child({ namespace: 'gallery' }); const moduleLog = log.child({ namespace: 'gallery' });
@ -9,11 +10,17 @@ export const addReceivedPageOfImagesFulfilledListener = () => {
startAppListening({ startAppListening({
actionCreator: receivedPageOfImages.fulfilled, actionCreator: receivedPageOfImages.fulfilled,
effect: (action, { getState, dispatch }) => { effect: (action, { getState, dispatch }) => {
const page = action.payload; const { items } = action.payload;
moduleLog.debug( moduleLog.debug(
{ data: { payload: action.payload } }, { data: { payload: action.payload } },
`Received ${page.items.length} images` `Received ${items.length} images`
); );
items.forEach((image) => {
dispatch(
imagesApi.util.upsertQueryData('getImageDTO', image.image_name, image)
);
});
}, },
}); });
}; };

View File

@ -0,0 +1,19 @@
import { startAppListening } from '..';
import { log } from 'app/logging/useLogger';
import {
imagesAddedToBatch,
selectionAddedToBatch,
} from 'features/batch/store/batchSlice';
const moduleLog = log.child({ namespace: 'batch' });
export const addSelectionAddedToBatchListener = () => {
startAppListening({
actionCreator: selectionAddedToBatch,
effect: (action, { dispatch, getState }) => {
const { selection } = getState().gallery;
dispatch(imagesAddedToBatch(selection));
},
});
};

View File

@ -1,6 +1,5 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { appSocketConnected, socketConnected } from 'services/events/actions'; import { appSocketConnected, socketConnected } from 'services/events/actions';
import { receivedPageOfImages } from 'services/api/thunks/image';
import { receivedOpenAPISchema } from 'services/api/thunks/schema'; import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { startAppListening } from '../..'; import { startAppListening } from '../..';
@ -14,19 +13,10 @@ export const addSocketConnectedEventListener = () => {
moduleLog.debug({ timestamp }, 'Connected'); moduleLog.debug({ timestamp }, 'Connected');
const { nodes, config, images } = getState(); const { nodes, config } = getState();
const { disabledTabs } = config; const { disabledTabs } = config;
if (!images.ids.length) {
dispatch(
receivedPageOfImages({
categories: ['general'],
is_intermediate: false,
})
);
}
if (!nodes.schema && !disabledTabs.includes('nodes')) { if (!nodes.schema && !disabledTabs.includes('nodes')) {
dispatch(receivedOpenAPISchema()); dispatch(receivedOpenAPISchema());
} }

View File

@ -2,7 +2,7 @@ import { stagingAreaImageSaved } from 'features/canvas/store/actions';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { imageUpdated } from 'services/api/thunks/image'; import { imageUpdated } from 'services/api/thunks/image';
import { imageUpserted } from 'features/gallery/store/imagesSlice'; import { imageUpserted } from 'features/gallery/store/gallerySlice';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
const moduleLog = log.child({ namespace: 'canvas' }); const moduleLog = log.child({ namespace: 'canvas' });

View File

@ -8,7 +8,7 @@ import { controlNetSelector } from 'features/controlNet/store/controlNetSlice';
import { forEach, uniqBy } from 'lodash-es'; import { forEach, uniqBy } from 'lodash-es';
import { imageUrlsReceived } from 'services/api/thunks/image'; import { imageUrlsReceived } from 'services/api/thunks/image';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { selectImagesEntities } from 'features/gallery/store/imagesSlice'; import { selectImagesEntities } from 'features/gallery/store/gallerySlice';
const moduleLog = log.child({ namespace: 'images' }); const moduleLog = log.child({ namespace: 'images' });
@ -36,7 +36,7 @@ const selectAllUsedImages = createSelector(
nodes.nodes.forEach((node) => { nodes.nodes.forEach((node) => {
forEach(node.data.inputs, (input) => { forEach(node.data.inputs, (input) => {
if (input.type === 'image' && input.value) { if (input.type === 'image' && input.value) {
allUsedImages.push(input.value); allUsedImages.push(input.value.image_name);
} }
}); });
}); });

View File

@ -8,31 +8,32 @@ import {
import dynamicMiddlewares from 'redux-dynamic-middlewares'; import dynamicMiddlewares from 'redux-dynamic-middlewares';
import { rememberEnhancer, rememberReducer } from 'redux-remember'; import { rememberEnhancer, rememberReducer } from 'redux-remember';
import batchReducer from 'features/batch/store/batchSlice';
import canvasReducer from 'features/canvas/store/canvasSlice'; import canvasReducer from 'features/canvas/store/canvasSlice';
import controlNetReducer from 'features/controlNet/store/controlNetSlice'; import controlNetReducer from 'features/controlNet/store/controlNetSlice';
import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice';
import boardsReducer from 'features/gallery/store/boardSlice';
import galleryReducer from 'features/gallery/store/gallerySlice'; import galleryReducer from 'features/gallery/store/gallerySlice';
import imagesReducer from 'features/gallery/store/imagesSlice'; import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice';
import lightboxReducer from 'features/lightbox/store/lightboxSlice'; import lightboxReducer from 'features/lightbox/store/lightboxSlice';
import loraReducer from 'features/lora/store/loraSlice';
import nodesReducer from 'features/nodes/store/nodesSlice';
import generationReducer from 'features/parameters/store/generationSlice'; import generationReducer from 'features/parameters/store/generationSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import systemReducer from 'features/system/store/systemSlice';
// import sessionReducer from 'features/system/store/sessionSlice';
import nodesReducer from 'features/nodes/store/nodesSlice';
import boardsReducer from 'features/gallery/store/boardSlice';
import configReducer from 'features/system/store/configSlice'; import configReducer from 'features/system/store/configSlice';
import systemReducer from 'features/system/store/systemSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice'; import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import uiReducer from 'features/ui/store/uiSlice'; import uiReducer from 'features/ui/store/uiSlice';
import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice';
import { listenerMiddleware } from './middleware/listenerMiddleware'; import { listenerMiddleware } from './middleware/listenerMiddleware';
import { actionSanitizer } from './middleware/devtools/actionSanitizer'; import { api } from 'services/api';
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
import { LOCALSTORAGE_PREFIX } from './constants'; import { LOCALSTORAGE_PREFIX } from './constants';
import { serialize } from './enhancers/reduxRemember/serialize'; import { serialize } from './enhancers/reduxRemember/serialize';
import { unserialize } from './enhancers/reduxRemember/unserialize'; import { unserialize } from './enhancers/reduxRemember/unserialize';
import { api } from 'services/api'; import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
const allReducers = { const allReducers = {
canvas: canvasReducer, canvas: canvasReducer,
@ -45,11 +46,12 @@ const allReducers = {
config: configReducer, config: configReducer,
ui: uiReducer, ui: uiReducer,
hotkeys: hotkeysReducer, hotkeys: hotkeysReducer,
images: imagesReducer,
controlNet: controlNetReducer, controlNet: controlNetReducer,
boards: boardsReducer, boards: boardsReducer,
// session: sessionReducer,
dynamicPrompts: dynamicPromptsReducer, dynamicPrompts: dynamicPromptsReducer,
batch: batchReducer,
imageDeletion: imageDeletionReducer,
lora: loraReducer,
[api.reducerPath]: api.reducer, [api.reducerPath]: api.reducer,
}; };
@ -68,6 +70,8 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'ui', 'ui',
'controlNet', 'controlNet',
'dynamicPrompts', 'dynamicPrompts',
'batch',
'lora',
// 'boards', // 'boards',
// 'hotkeys', // 'hotkeys',
// 'config', // 'config',

View File

@ -15,10 +15,25 @@ export interface IAIButtonProps extends ButtonProps {
} }
const IAIButton = forwardRef((props: IAIButtonProps, forwardedRef) => { const IAIButton = forwardRef((props: IAIButtonProps, forwardedRef) => {
const { children, tooltip = '', tooltipProps, isChecked, ...rest } = props; const {
children,
tooltip = '',
tooltipProps: { placement = 'top', hasArrow = true, ...tooltipProps } = {},
isChecked,
...rest
} = props;
return ( return (
<Tooltip label={tooltip} {...tooltipProps}> <Tooltip
<Button ref={forwardedRef} aria-checked={isChecked} {...rest}> label={tooltip}
placement={placement}
hasArrow={hasArrow}
{...tooltipProps}
>
<Button
ref={forwardedRef}
colorScheme={isChecked ? 'accent' : 'base'}
{...rest}
>
{children} {children}
</Button> </Button>
</Tooltip> </Tooltip>

View File

@ -4,22 +4,25 @@ import {
Collapse, Collapse,
Flex, Flex,
Spacer, Spacer,
Switch, Text,
useColorMode, useColorMode,
useDisclosure,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { AnimatePresence, motion } from 'framer-motion';
import { PropsWithChildren, memo } from 'react'; import { PropsWithChildren, memo } from 'react';
import { mode } from 'theme/util/mode'; import { mode } from 'theme/util/mode';
export type IAIToggleCollapseProps = PropsWithChildren & { export type IAIToggleCollapseProps = PropsWithChildren & {
label: string; label: string;
isOpen: boolean; activeLabel?: string;
onToggle: () => void; defaultIsOpen?: boolean;
withSwitch?: boolean;
}; };
const IAICollapse = (props: IAIToggleCollapseProps) => { const IAICollapse = (props: IAIToggleCollapseProps) => {
const { label, isOpen, onToggle, children, withSwitch = false } = props; const { label, activeLabel, children, defaultIsOpen = false } = props;
const { isOpen, onToggle } = useDisclosure({ defaultIsOpen });
const { colorMode } = useColorMode(); const { colorMode } = useColorMode();
return ( return (
<Box> <Box>
<Flex <Flex
@ -28,6 +31,7 @@ const IAICollapse = (props: IAIToggleCollapseProps) => {
alignItems: 'center', alignItems: 'center',
p: 2, p: 2,
px: 4, px: 4,
gap: 2,
borderTopRadius: 'base', borderTopRadius: 'base',
borderBottomRadius: isOpen ? 0 : 'base', borderBottomRadius: isOpen ? 0 : 'base',
bg: isOpen bg: isOpen
@ -48,9 +52,31 @@ const IAICollapse = (props: IAIToggleCollapseProps) => {
}} }}
> >
{label} {label}
<AnimatePresence>
{activeLabel && (
<motion.div
key="statusText"
initial={{
opacity: 0,
}}
animate={{
opacity: 1,
transition: { duration: 0.1 },
}}
exit={{
opacity: 0,
transition: { duration: 0.1 },
}}
>
<Text
sx={{ color: 'accent.500', _dark: { color: 'accent.300' } }}
>
{activeLabel}
</Text>
</motion.div>
)}
</AnimatePresence>
<Spacer /> <Spacer />
{withSwitch && <Switch isChecked={isOpen} pointerEvents="none" />}
{!withSwitch && (
<ChevronUpIcon <ChevronUpIcon
sx={{ sx={{
w: '1rem', w: '1rem',
@ -60,7 +86,6 @@ const IAICollapse = (props: IAIToggleCollapseProps) => {
transitionDuration: 'normal', transitionDuration: 'normal',
}} }}
/> />
)}
</Flex> </Flex>
<Collapse in={isOpen} animateOpacity style={{ overflow: 'unset' }}> <Collapse in={isOpen} animateOpacity style={{ overflow: 'unset' }}>
<Box <Box

View File

@ -1,19 +1,20 @@
import { import {
Box,
ChakraProps, ChakraProps,
Flex, Flex,
Icon, Icon,
IconButtonProps,
Image, Image,
useColorMode, useColorMode,
useColorModeValue,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useDraggable, useDroppable } from '@dnd-kit/core';
import { useCombinedRefs } from '@dnd-kit/utilities'; import { useCombinedRefs } from '@dnd-kit/utilities';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback'; import {
IAILoadingImageFallback,
IAINoContentFallback,
} from 'common/components/IAIImageFallback';
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay'; import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
import { AnimatePresence } from 'framer-motion'; import { AnimatePresence } from 'framer-motion';
import { ReactElement, SyntheticEvent } from 'react'; import { MouseEvent, ReactElement, SyntheticEvent } from 'react';
import { memo, useRef } from 'react'; import { memo, useRef } from 'react';
import { FaImage, FaUndo, FaUpload } from 'react-icons/fa'; import { FaImage, FaUndo, FaUpload } from 'react-icons/fa';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
@ -22,81 +23,97 @@ import IAIDropOverlay from './IAIDropOverlay';
import { PostUploadAction } from 'services/api/thunks/image'; import { PostUploadAction } from 'services/api/thunks/image';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton'; import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { mode } from 'theme/util/mode'; import { mode } from 'theme/util/mode';
import {
TypesafeDraggableData,
TypesafeDroppableData,
isValidDrop,
useDraggable,
useDroppable,
} from 'app/components/ImageDnd/typesafeDnd';
type IAIDndImageProps = { type IAIDndImageProps = {
image: ImageDTO | null | undefined; imageDTO: ImageDTO | undefined;
onDrop: (droppedImage: ImageDTO) => void;
onReset?: () => void;
onError?: (event: SyntheticEvent<HTMLImageElement>) => void; onError?: (event: SyntheticEvent<HTMLImageElement>) => void;
onLoad?: (event: SyntheticEvent<HTMLImageElement>) => void; onLoad?: (event: SyntheticEvent<HTMLImageElement>) => void;
resetIconSize?: IconButtonProps['size']; onClick?: (event: MouseEvent<HTMLDivElement>) => void;
onClickReset?: (event: MouseEvent<HTMLButtonElement>) => void;
withResetIcon?: boolean; withResetIcon?: boolean;
resetIcon?: ReactElement;
resetTooltip?: string;
withMetadataOverlay?: boolean; withMetadataOverlay?: boolean;
isDragDisabled?: boolean; isDragDisabled?: boolean;
isDropDisabled?: boolean; isDropDisabled?: boolean;
isUploadDisabled?: boolean; isUploadDisabled?: boolean;
fallback?: ReactElement;
payloadImage?: ImageDTO | null | undefined;
minSize?: number; minSize?: number;
postUploadAction?: PostUploadAction; postUploadAction?: PostUploadAction;
imageSx?: ChakraProps['sx']; imageSx?: ChakraProps['sx'];
fitContainer?: boolean; fitContainer?: boolean;
droppableData?: TypesafeDroppableData;
draggableData?: TypesafeDraggableData;
dropLabel?: string;
isSelected?: boolean;
thumbnail?: boolean;
noContentFallback?: ReactElement;
}; };
const IAIDndImage = (props: IAIDndImageProps) => { const IAIDndImage = (props: IAIDndImageProps) => {
const { const {
image, imageDTO,
onDrop, onClickReset,
onReset,
onError, onError,
resetIconSize = 'md', onClick,
withResetIcon = false, withResetIcon = false,
withMetadataOverlay = false, withMetadataOverlay = false,
isDropDisabled = false, isDropDisabled = false,
isDragDisabled = false, isDragDisabled = false,
isUploadDisabled = false, isUploadDisabled = false,
fallback = <IAIImageLoadingFallback />,
payloadImage,
minSize = 24, minSize = 24,
postUploadAction, postUploadAction,
imageSx, imageSx,
fitContainer = false, fitContainer = false,
droppableData,
draggableData,
dropLabel,
isSelected = false,
thumbnail = false,
resetTooltip = 'Reset',
resetIcon = <FaUndo />,
noContentFallback = <IAINoContentFallback icon={FaImage} />,
} = props; } = props;
const dndId = useRef(uuidv4());
const { colorMode } = useColorMode(); const { colorMode } = useColorMode();
const { const dndId = useRef(uuidv4());
isOver,
setNodeRef: setDroppableRef,
active: isDropActive,
} = useDroppable({
id: dndId.current,
disabled: isDropDisabled,
data: {
handleDrop: onDrop,
},
});
const { const {
attributes, attributes,
listeners, listeners,
setNodeRef: setDraggableRef, setNodeRef: setDraggableRef,
isDragging, isDragging,
active,
} = useDraggable({ } = useDraggable({
id: dndId.current, id: dndId.current,
data: { disabled: isDragDisabled || !imageDTO,
image: payloadImage ? payloadImage : image, data: draggableData,
},
disabled: isDragDisabled || !image,
}); });
const { isOver, setNodeRef: setDroppableRef } = useDroppable({
id: dndId.current,
disabled: isDropDisabled,
data: droppableData,
});
const setDndRef = useCombinedRefs(setDroppableRef, setDraggableRef);
const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({ const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({
postUploadAction, postUploadAction,
isDisabled: isUploadDisabled, isDisabled: isUploadDisabled,
}); });
const setNodeRef = useCombinedRefs(setDroppableRef, setDraggableRef); const resetIconShadow = useColorModeValue(
`drop-shadow(0px 0px 0.1rem var(--invokeai-colors-base-600))`,
`drop-shadow(0px 0px 0.1rem var(--invokeai-colors-base-800))`
);
const uploadButtonStyles = isUploadDisabled const uploadButtonStyles = isUploadDisabled
? {} ? {}
@ -117,16 +134,16 @@ const IAIDndImage = (props: IAIDndImageProps) => {
alignItems: 'center', alignItems: 'center',
justifyContent: 'center', justifyContent: 'center',
position: 'relative', position: 'relative',
minW: minSize, minW: minSize ? minSize : undefined,
minH: minSize, minH: minSize ? minSize : undefined,
userSelect: 'none', userSelect: 'none',
cursor: isDragDisabled || !image ? 'auto' : 'grab', cursor: isDragDisabled || !imageDTO ? 'default' : 'pointer',
}} }}
{...attributes} {...attributes}
{...listeners} {...listeners}
ref={setNodeRef} ref={setDndRef}
> >
{image && ( {imageDTO && (
<Flex <Flex
sx={{ sx={{
w: 'full', w: 'full',
@ -137,42 +154,50 @@ const IAIDndImage = (props: IAIDndImageProps) => {
}} }}
> >
<Image <Image
src={image.image_url} onClick={onClick}
fallback={fallback} src={thumbnail ? imageDTO.thumbnail_url : imageDTO.image_url}
fallbackStrategy="beforeLoadOrError"
fallback={<IAILoadingImageFallback image={imageDTO} />}
onError={onError} onError={onError}
objectFit="contain"
draggable={false} draggable={false}
sx={{ sx={{
objectFit: 'contain',
maxW: 'full', maxW: 'full',
maxH: 'full', maxH: 'full',
borderRadius: 'base', borderRadius: 'base',
shadow: isSelected ? 'selected.light' : undefined,
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
...imageSx, ...imageSx,
}} }}
/> />
{withMetadataOverlay && <ImageMetadataOverlay image={image} />} {withMetadataOverlay && <ImageMetadataOverlay image={imageDTO} />}
{onReset && withResetIcon && ( {onClickReset && withResetIcon && (
<Box <IAIIconButton
onClick={onClickReset}
aria-label={resetTooltip}
tooltip={resetTooltip}
icon={resetIcon}
size="sm"
variant="link"
sx={{ sx={{
position: 'absolute', position: 'absolute',
top: 0, top: 1,
right: 0, insetInlineEnd: 1,
p: 0,
minW: 0,
svg: {
transitionProperty: 'common',
transitionDuration: 'normal',
fill: 'base.100',
_hover: { fill: 'base.50' },
filter: resetIconShadow,
},
}} }}
>
<IAIIconButton
size={resetIconSize}
tooltip="Reset Image"
aria-label="Reset Image"
icon={<FaUndo />}
onClick={onReset}
/> />
</Box>
)} )}
<AnimatePresence>
{isDropActive && <IAIDropOverlay isOver={isOver} />}
</AnimatePresence>
</Flex> </Flex>
)} )}
{!image && ( {!imageDTO && !isUploadDisabled && (
<> <>
<Flex <Flex
sx={{ sx={{
@ -191,17 +216,20 @@ const IAIDndImage = (props: IAIDndImageProps) => {
> >
<input {...getUploadInputProps()} /> <input {...getUploadInputProps()} />
<Icon <Icon
as={isUploadDisabled ? FaImage : FaUpload} as={FaUpload}
sx={{ sx={{
boxSize: 12, boxSize: 16,
}} }}
/> />
</Flex> </Flex>
<AnimatePresence>
{isDropActive && <IAIDropOverlay isOver={isOver} />}
</AnimatePresence>
</> </>
)} )}
{!imageDTO && isUploadDisabled && noContentFallback}
<AnimatePresence>
{isValidDrop(droppableData, active) && !isDragging && (
<IAIDropOverlay isOver={isOver} label={dropLabel} />
)}
</AnimatePresence>
</Flex> </Flex>
); );
}; };

View File

@ -62,7 +62,7 @@ export const IAIDropOverlay = (props: Props) => {
w: 'full', w: 'full',
h: 'full', h: 'full',
opacity: 1, opacity: 1,
borderWidth: 2, borderWidth: 3,
borderColor: isOver borderColor: isOver
? mode('base.50', 'base.200')(colorMode) ? mode('base.50', 'base.200')(colorMode)
: mode('base.100', 'base.500')(colorMode), : mode('base.100', 'base.500')(colorMode),
@ -78,10 +78,10 @@ export const IAIDropOverlay = (props: Props) => {
sx={{ sx={{
fontSize: '2xl', fontSize: '2xl',
fontWeight: 600, fontWeight: 600,
transform: isOver ? 'scale(1.1)' : 'scale(1)', transform: isOver ? 'scale(1.02)' : 'scale(1)',
color: isOver color: isOver
? mode('base.100', 'base.100')(colorMode) ? mode('base.50', 'base.50')(colorMode)
: mode('base.200', 'base.500')(colorMode), : mode('base.100', 'base.200')(colorMode),
transitionProperty: 'common', transitionProperty: 'common',
transitionDuration: '0.1s', transitionDuration: '0.1s',
}} }}

View File

@ -29,7 +29,7 @@ const IAIIconButton = forwardRef((props: IAIIconButtonProps, forwardedRef) => {
<IconButton <IconButton
ref={forwardedRef} ref={forwardedRef}
role={role} role={role}
aria-checked={isChecked !== undefined ? isChecked : undefined} colorScheme={isChecked ? 'accent' : 'base'}
{...rest} {...rest}
/> />
</Tooltip> </Tooltip>

View File

@ -1,73 +1,82 @@
import { import {
As, As,
ChakraProps,
Flex, Flex,
FlexProps,
Icon, Icon,
IconProps, Skeleton,
Spinner, Spinner,
SpinnerProps, StyleProps,
useColorMode, Text,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { FaImage } from 'react-icons/fa'; import { FaImage } from 'react-icons/fa';
import { mode } from 'theme/util/mode'; import { ImageDTO } from 'services/api/types';
type Props = FlexProps & { type Props = { image: ImageDTO | undefined };
spinnerProps?: SpinnerProps;
}; export const IAILoadingImageFallback = (props: Props) => {
if (props.image) {
return (
<Skeleton
sx={{
w: `${props.image.width}px`,
h: 'auto',
objectFit: 'contain',
aspectRatio: `${props.image.width}/${props.image.height}`,
}}
/>
);
}
export const IAIImageLoadingFallback = (props: Props) => {
const { spinnerProps, ...rest } = props;
const { sx, ...restFlexProps } = rest;
const { colorMode } = useColorMode();
return ( return (
<Flex <Flex
sx={{ sx={{
bg: mode('base.200', 'base.900')(colorMode),
opacity: 0.7, opacity: 0.7,
w: 'full', w: 'full',
h: 'full', h: 'full',
alignItems: 'center', alignItems: 'center',
justifyContent: 'center', justifyContent: 'center',
borderRadius: 'base', borderRadius: 'base',
...sx, bg: 'base.200',
_dark: {
bg: 'base.900',
},
}} }}
{...restFlexProps}
> >
<Spinner size="xl" {...spinnerProps} /> <Spinner size="xl" />
</Flex> </Flex>
); );
}; };
type IAINoImageFallbackProps = { type IAINoImageFallbackProps = {
flexProps?: FlexProps; label?: string;
iconProps?: IconProps; icon?: As;
as?: As; boxSize?: StyleProps['boxSize'];
sx?: ChakraProps['sx'];
}; };
export const IAINoImageFallback = (props: IAINoImageFallbackProps) => { export const IAINoContentFallback = (props: IAINoImageFallbackProps) => {
const { sx: flexSx, ...restFlexProps } = props.flexProps ?? { sx: {} }; const { icon = FaImage, boxSize = 16 } = props;
const { sx: iconSx, ...restIconProps } = props.iconProps ?? { sx: {} };
const { colorMode } = useColorMode();
return ( return (
<Flex <Flex
sx={{ sx={{
bg: mode('base.200', 'base.900')(colorMode),
opacity: 0.7,
w: 'full', w: 'full',
h: 'full', h: 'full',
alignItems: 'center', alignItems: 'center',
justifyContent: 'center', justifyContent: 'center',
borderRadius: 'base', borderRadius: 'base',
...flexSx, flexDir: 'column',
gap: 2,
userSelect: 'none',
color: 'base.700',
_dark: {
color: 'base.500',
},
...props.sx,
}} }}
{...restFlexProps}
> >
<Icon <Icon as={icon} boxSize={boxSize} opacity={0.7} />
as={props.as ?? FaImage} {props.label && <Text textAlign="center">{props.label}</Text>}
sx={{ color: mode('base.700', 'base.500')(colorMode), ...iconSx }}
{...restIconProps}
/>
</Flex> </Flex>
); );
}; };

View File

@ -1,15 +1,16 @@
import { Tooltip, useColorMode, useToken } from '@chakra-ui/react'; import { Tooltip, useColorMode, useToken } from '@chakra-ui/react';
import { MultiSelect, MultiSelectProps } from '@mantine/core'; import { MultiSelect, MultiSelectProps } from '@mantine/core';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { memo } from 'react'; import { RefObject, memo } from 'react';
import { mode } from 'theme/util/mode'; import { mode } from 'theme/util/mode';
type IAIMultiSelectProps = MultiSelectProps & { type IAIMultiSelectProps = MultiSelectProps & {
tooltip?: string; tooltip?: string;
inputRef?: RefObject<HTMLInputElement>;
}; };
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
const { searchable = true, tooltip, ...rest } = props; const { searchable = true, tooltip, inputRef, ...rest } = props;
const { const {
base50, base50,
base100, base100,
@ -33,6 +34,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
return ( return (
<Tooltip label={tooltip} placement="top" hasArrow> <Tooltip label={tooltip} placement="top" hasArrow>
<MultiSelect <MultiSelect
ref={inputRef}
searchable={searchable} searchable={searchable}
styles={() => ({ styles={() => ({
label: { label: {
@ -61,7 +63,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
'&:focus-within': { '&:focus-within': {
borderColor: mode(accent200, accent600)(colorMode), borderColor: mode(accent200, accent600)(colorMode),
}, },
'&:disabled': { '&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode), backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode), color: mode(base600, base400)(colorMode),
}, },

View File

@ -64,7 +64,7 @@ const IAIMantineSelect = (props: IAISelectProps) => {
'&:focus-within': { '&:focus-within': {
borderColor: mode(accent200, accent600)(colorMode), borderColor: mode(accent200, accent600)(colorMode),
}, },
'&:disabled': { '&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode), backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode), color: mode(base600, base400)(colorMode),
}, },

View File

@ -36,7 +36,6 @@ const IAISwitch = (props: Props) => {
isDisabled={isDisabled} isDisabled={isDisabled}
width={width} width={width}
display="flex" display="flex"
gap={4}
alignItems="center" alignItems="center"
{...formControlProps} {...formControlProps}
> >
@ -47,6 +46,7 @@ const IAISwitch = (props: Props) => {
sx={{ sx={{
cursor: isDisabled ? 'not-allowed' : 'pointer', cursor: isDisabled ? 'not-allowed' : 'pointer',
...formLabelProps?.sx, ...formLabelProps?.sx,
pe: 4,
}} }}
{...formLabelProps} {...formLabelProps}
> >

View File

@ -1,27 +1,49 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { validateSeedWeights } from 'common/util/seedWeightPairs'; import { validateSeedWeights } from 'common/util/seedWeightPairs';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import { systemSelector } from 'features/system/store/systemSelectors'; import { systemSelector } from 'features/system/store/systemSelectors';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import {
modelsApi,
useGetMainModelsQuery,
} from '../../services/api/endpoints/models';
const readinessSelector = createSelector( const readinessSelector = createSelector(
[generationSelector, systemSelector, activeTabNameSelector], [stateSelector, activeTabNameSelector],
(generation, system, activeTabName) => { (state, activeTabName) => {
const { generation, system, batch } = state;
const { shouldGenerateVariations, seedWeights, initialImage, seed } = const { shouldGenerateVariations, seedWeights, initialImage, seed } =
generation; generation;
const { isProcessing, isConnected } = system; const { isProcessing, isConnected } = system;
const {
isEnabled: isBatchEnabled,
asInitialImage,
imageNames: batchImageNames,
} = batch;
let isReady = true; let isReady = true;
const reasonsWhyNotReady: string[] = []; const reasonsWhyNotReady: string[] = [];
if (activeTabName === 'img2img' && !initialImage) { if (
activeTabName === 'img2img' &&
!initialImage &&
!(asInitialImage && batchImageNames.length > 1)
) {
isReady = false; isReady = false;
reasonsWhyNotReady.push('No initial image selected'); reasonsWhyNotReady.push('No initial image selected');
} }
const { isSuccess: mainModelsSuccessfullyLoaded } =
modelsApi.endpoints.getMainModels.select()(state);
if (!mainModelsSuccessfullyLoaded) {
isReady = false;
reasonsWhyNotReady.push('Models are not loaded');
}
// TODO: job queue // TODO: job queue
// Cannot generate if already processing an image // Cannot generate if already processing an image
if (isProcessing) { if (isProcessing) {

View File

@ -0,0 +1,67 @@
import {
Flex,
FormControl,
FormLabel,
Heading,
Spacer,
Switch,
Text,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISwitch from 'common/components/IAISwitch';
import { ControlNetConfig } from 'features/controlNet/store/controlNetSlice';
import { ChangeEvent, memo, useCallback } from 'react';
import { controlNetToggled } from '../store/batchSlice';
type Props = {
controlNet: ControlNetConfig;
};
const selector = createSelector(
[stateSelector, (state, controlNetId: string) => controlNetId],
(state, controlNetId) => {
const isControlNetEnabled = state.batch.controlNets.includes(controlNetId);
return { isControlNetEnabled };
},
defaultSelectorOptions
);
const BatchControlNet = (props: Props) => {
const dispatch = useAppDispatch();
const { isControlNetEnabled } = useAppSelector((state) =>
selector(state, props.controlNet.controlNetId)
);
const { processorType, model } = props.controlNet;
const handleChangeAsControlNet = useCallback(() => {
dispatch(controlNetToggled(props.controlNet.controlNetId));
}, [dispatch, props.controlNet.controlNetId]);
return (
<Flex
layerStyle="second"
sx={{ flexDir: 'column', gap: 1, p: 4, borderRadius: 'base' }}
>
<Flex sx={{ justifyContent: 'space-between' }}>
<FormControl as={Flex} onClick={handleChangeAsControlNet}>
<FormLabel>
<Heading size="sm">ControlNet</Heading>
</FormLabel>
<Spacer />
<Switch isChecked={isControlNetEnabled} />
</FormControl>
</Flex>
<Text>
<strong>Model:</strong> {model}
</Text>
<Text>
<strong>Processor:</strong> {processorType}
</Text>
</Flex>
);
};
export default memo(BatchControlNet);

View File

@ -0,0 +1,116 @@
import { Box, Icon, Skeleton } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage';
import {
batchImageRangeEndSelected,
batchImageSelected,
batchImageSelectionToggled,
imageRemovedFromBatch,
} from 'features/batch/store/batchSlice';
import { MouseEvent, memo, useCallback, useMemo } from 'react';
import { FaExclamationCircle } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
const makeSelector = (image_name: string) =>
createSelector(
[stateSelector],
(state) => ({
selectionCount: state.batch.selection.length,
isSelected: state.batch.selection.includes(image_name),
}),
defaultSelectorOptions
);
type BatchImageProps = {
imageName: string;
};
const BatchImage = (props: BatchImageProps) => {
const {
currentData: imageDTO,
isFetching,
isError,
isSuccess,
} = useGetImageDTOQuery(props.imageName);
const dispatch = useAppDispatch();
const selector = useMemo(
() => makeSelector(props.imageName),
[props.imageName]
);
const { isSelected, selectionCount } = useAppSelector(selector);
const handleClickRemove = useCallback(() => {
dispatch(imageRemovedFromBatch(props.imageName));
}, [dispatch, props.imageName]);
const handleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
if (e.shiftKey) {
dispatch(batchImageRangeEndSelected(props.imageName));
} else if (e.ctrlKey || e.metaKey) {
dispatch(batchImageSelectionToggled(props.imageName));
} else {
dispatch(batchImageSelected(props.imageName));
}
},
[dispatch, props.imageName]
);
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
if (selectionCount > 1) {
return {
id: 'batch',
payloadType: 'BATCH_SELECTION',
};
}
if (imageDTO) {
return {
id: 'batch',
payloadType: 'IMAGE_DTO',
payload: { imageDTO },
};
}
}, [imageDTO, selectionCount]);
if (isError) {
return <Icon as={FaExclamationCircle} />;
}
if (isFetching) {
return (
<Skeleton>
<Box w="full" h="full" aspectRatio="1/1" />
</Skeleton>
);
}
return (
<Box sx={{ position: 'relative', aspectRatio: '1/1' }}>
<IAIDndImage
imageDTO={imageDTO}
draggableData={draggableData}
isDropDisabled={true}
isUploadDisabled={true}
imageSx={{
w: 'full',
h: 'full',
}}
onClick={handleClick}
isSelected={isSelected}
onClickReset={handleClickRemove}
resetTooltip="Remove from batch"
withResetIcon
thumbnail
/>
</Box>
);
};
export default memo(BatchImage);

View File

@ -0,0 +1,31 @@
import { Box } from '@chakra-ui/react';
import BatchImageGrid from './BatchImageGrid';
import IAIDropOverlay from 'common/components/IAIDropOverlay';
import {
AddToBatchDropData,
isValidDrop,
useDroppable,
} from 'app/components/ImageDnd/typesafeDnd';
const droppableData: AddToBatchDropData = {
id: 'batch',
actionType: 'ADD_TO_BATCH',
};
const BatchImageContainer = () => {
const { isOver, setNodeRef, active } = useDroppable({
id: 'batch-manager',
data: droppableData,
});
return (
<Box ref={setNodeRef} position="relative" w="full" h="full">
<BatchImageGrid />
{isValidDrop(droppableData, active) && (
<IAIDropOverlay isOver={isOver} label="Add to Batch" />
)}
</Box>
);
};
export default BatchImageContainer;

View File

@ -0,0 +1,54 @@
import { FaImages } from 'react-icons/fa';
import { Grid, GridItem } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import BatchImage from './BatchImage';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
const selector = createSelector(
stateSelector,
(state) => {
const imageNames = state.batch.imageNames.concat().reverse();
return { imageNames };
},
defaultSelectorOptions
);
const BatchImageGrid = () => {
const { imageNames } = useAppSelector(selector);
if (imageNames.length === 0) {
return (
<IAINoContentFallback
icon={FaImages}
boxSize={16}
label="No images in Batch"
/>
);
}
return (
<Grid
sx={{
position: 'absolute',
flexWrap: 'wrap',
w: 'full',
minH: 0,
maxH: 'full',
overflowY: 'scroll',
gridTemplateColumns: `repeat(auto-fill, minmax(128px, 1fr))`,
}}
>
{imageNames.map((imageName) => (
<GridItem key={imageName} sx={{ p: 1.5 }}>
<BatchImage imageName={imageName} />
</GridItem>
))}
</Grid>
);
};
export default BatchImageGrid;

View File

@ -0,0 +1,103 @@
import { Flex, Heading, Spacer } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useCallback } from 'react';
import IAISwitch from 'common/components/IAISwitch';
import {
asInitialImageToggled,
batchReset,
isEnabledChanged,
} from 'features/batch/store/batchSlice';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton';
import BatchImageContainer from './BatchImageGrid';
import { map } from 'lodash-es';
import BatchControlNet from './BatchControlNet';
const selector = createSelector(
stateSelector,
(state) => {
const { controlNets } = state.controlNet;
const {
imageNames,
asInitialImage,
controlNets: batchControlNets,
isEnabled,
} = state.batch;
return {
imageCount: imageNames.length,
asInitialImage,
controlNets,
batchControlNets,
isEnabled,
};
},
defaultSelectorOptions
);
const BatchManager = () => {
const dispatch = useAppDispatch();
const { imageCount, isEnabled, controlNets, batchControlNets } =
useAppSelector(selector);
const handleResetBatch = useCallback(() => {
dispatch(batchReset());
}, [dispatch]);
const handleToggle = useCallback(() => {
dispatch(isEnabledChanged(!isEnabled));
}, [dispatch, isEnabled]);
const handleChangeAsInitialImage = useCallback(() => {
dispatch(asInitialImageToggled());
}, [dispatch]);
return (
<Flex
sx={{
h: 'full',
w: 'full',
flexDir: 'column',
position: 'relative',
gap: 2,
minW: 0,
}}
>
<Flex sx={{ alignItems: 'center' }}>
<Heading
size={'md'}
sx={{ color: 'base.800', _dark: { color: 'base.200' } }}
>
{imageCount || 'No'} images
</Heading>
<Spacer />
<IAIButton onClick={handleResetBatch}>Reset</IAIButton>
</Flex>
<Flex
sx={{
alignItems: 'center',
flexDir: 'column',
gap: 4,
}}
>
<IAISwitch
label="Use as Initial Image"
onChange={handleChangeAsInitialImage}
/>
{map(controlNets, (controlNet) => {
return (
<BatchControlNet
key={controlNet.controlNetId}
controlNet={controlNet}
/>
);
})}
</Flex>
<BatchImageContainer />
</Flex>
);
};
export default BatchManager;

View File

@ -0,0 +1,142 @@
import { PayloadAction, createAction, createSlice } from '@reduxjs/toolkit';
import { uniq } from 'lodash-es';
import { imageDeleted } from 'services/api/thunks/image';
type BatchState = {
isEnabled: boolean;
imageNames: string[];
asInitialImage: boolean;
controlNets: string[];
selection: string[];
};
export const initialBatchState: BatchState = {
isEnabled: false,
imageNames: [],
asInitialImage: false,
controlNets: [],
selection: [],
};
const batch = createSlice({
name: 'batch',
initialState: initialBatchState,
reducers: {
isEnabledChanged: (state, action: PayloadAction<boolean>) => {
state.isEnabled = action.payload;
},
imageAddedToBatch: (state, action: PayloadAction<string>) => {
state.imageNames = uniq(state.imageNames.concat(action.payload));
},
imagesAddedToBatch: (state, action: PayloadAction<string[]>) => {
state.imageNames = uniq(state.imageNames.concat(action.payload));
},
imageRemovedFromBatch: (state, action: PayloadAction<string>) => {
state.imageNames = state.imageNames.filter(
(imageName) => action.payload !== imageName
);
state.selection = state.selection.filter(
(imageName) => action.payload !== imageName
);
},
imagesRemovedFromBatch: (state, action: PayloadAction<string[]>) => {
state.imageNames = state.imageNames.filter(
(imageName) => !action.payload.includes(imageName)
);
state.selection = state.selection.filter(
(imageName) => !action.payload.includes(imageName)
);
},
batchImageRangeEndSelected: (state, action: PayloadAction<string>) => {
const rangeEndImageName = action.payload;
const lastSelectedImage = state.selection[state.selection.length - 1];
const lastClickedIndex = state.imageNames.findIndex(
(n) => n === lastSelectedImage
);
const currentClickedIndex = state.imageNames.findIndex(
(n) => n === rangeEndImageName
);
if (lastClickedIndex > -1 && currentClickedIndex > -1) {
// We have a valid range!
const start = Math.min(lastClickedIndex, currentClickedIndex);
const end = Math.max(lastClickedIndex, currentClickedIndex);
const imagesToSelect = state.imageNames.slice(start, end + 1);
state.selection = uniq(state.selection.concat(imagesToSelect));
}
},
batchImageSelectionToggled: (state, action: PayloadAction<string>) => {
if (
state.selection.includes(action.payload) &&
state.selection.length > 1
) {
state.selection = state.selection.filter(
(imageName) => imageName !== action.payload
);
} else {
state.selection = uniq(state.selection.concat(action.payload));
}
},
batchImageSelected: (state, action: PayloadAction<string | null>) => {
state.selection = action.payload
? [action.payload]
: [String(state.imageNames[0])];
},
batchReset: (state) => {
state.imageNames = [];
state.selection = [];
},
asInitialImageToggled: (state) => {
state.asInitialImage = !state.asInitialImage;
},
controlNetAddedToBatch: (state, action: PayloadAction<string>) => {
state.controlNets = uniq(state.controlNets.concat(action.payload));
},
controlNetRemovedFromBatch: (state, action: PayloadAction<string>) => {
state.controlNets = state.controlNets.filter(
(controlNetId) => controlNetId !== action.payload
);
},
controlNetToggled: (state, action: PayloadAction<string>) => {
if (state.controlNets.includes(action.payload)) {
state.controlNets = state.controlNets.filter(
(controlNetId) => controlNetId !== action.payload
);
} else {
state.controlNets = uniq(state.controlNets.concat(action.payload));
}
},
},
extraReducers: (builder) => {
builder.addCase(imageDeleted.fulfilled, (state, action) => {
state.imageNames = state.imageNames.filter(
(imageName) => imageName !== action.meta.arg.image_name
);
state.selection = state.selection.filter(
(imageName) => imageName !== action.meta.arg.image_name
);
});
},
});
export const {
isEnabledChanged,
imageAddedToBatch,
imagesAddedToBatch,
imageRemovedFromBatch,
imagesRemovedFromBatch,
asInitialImageToggled,
controlNetAddedToBatch,
controlNetRemovedFromBatch,
batchReset,
controlNetToggled,
batchImageRangeEndSelected,
batchImageSelectionToggled,
batchImageSelected,
} = batch.actions;
export default batch.reducer;
export const selectionAddedToBatch = createAction(
'batch/selectionAddedToBatch'
);

View File

@ -1,20 +1,22 @@
import { memo, useCallback, useState } from 'react'; import { Box, Flex, SystemStyleObject } from '@chakra-ui/react';
import { ImageDTO } from 'services/api/types'; import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage';
import { IAILoadingImageFallback } from 'common/components/IAIImageFallback';
import { memo, useCallback, useMemo, useState } from 'react';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/thunks/image';
import { import {
ControlNetConfig, ControlNetConfig,
controlNetImageChanged, controlNetImageChanged,
controlNetSelector, controlNetSelector,
} from '../store/controlNetSlice'; } from '../store/controlNetSlice';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { Box, Flex, SystemStyleObject } from '@chakra-ui/react';
import IAIDndImage from 'common/components/IAIDndImage';
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
import IAIIconButton from 'common/components/IAIIconButton';
import { FaUndo } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { skipToken } from '@reduxjs/toolkit/dist/query';
const selector = createSelector( const selector = createSelector(
controlNetSelector, controlNetSelector,
@ -57,22 +59,6 @@ const ControlNetImagePreview = (props: Props) => {
isSuccess: isSuccessProcessedControlImage, isSuccess: isSuccessProcessedControlImage,
} = useGetImageDTOQuery(processedControlImageName ?? skipToken); } = useGetImageDTOQuery(processedControlImageName ?? skipToken);
const handleDrop = useCallback(
(droppedImage: ImageDTO) => {
if (controlImageName === droppedImage.image_name) {
return;
}
setIsMouseOverImage(false);
dispatch(
controlNetImageChanged({
controlNetId,
controlImage: droppedImage.image_name,
})
);
},
[controlImageName, controlNetId, dispatch]
);
const handleResetControlImage = useCallback(() => { const handleResetControlImage = useCallback(() => {
dispatch(controlNetImageChanged({ controlNetId, controlImage: null })); dispatch(controlNetImageChanged({ controlNetId, controlImage: null }));
}, [controlNetId, dispatch]); }, [controlNetId, dispatch]);
@ -84,6 +70,30 @@ const ControlNetImagePreview = (props: Props) => {
setIsMouseOverImage(false); setIsMouseOverImage(false);
}, []); }, []);
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
if (controlImage) {
return {
id: controlNetId,
payloadType: 'IMAGE_DTO',
payload: { imageDTO: controlImage },
};
}
}, [controlImage, controlNetId]);
const droppableData = useMemo<TypesafeDroppableData | undefined>(
() => ({
id: controlNetId,
actionType: 'SET_CONTROLNET_IMAGE',
context: { controlNetId },
}),
[controlNetId]
);
const postUploadAction = useMemo<PostUploadAction>(
() => ({ type: 'SET_CONTROLNET_IMAGE', controlNetId }),
[controlNetId]
);
const shouldShowProcessedImage = const shouldShowProcessedImage =
controlImage && controlImage &&
processedControlImage && processedControlImage &&
@ -104,14 +114,14 @@ const ControlNetImagePreview = (props: Props) => {
}} }}
> >
<IAIDndImage <IAIDndImage
image={controlImage} draggableData={draggableData}
onDrop={handleDrop} droppableData={droppableData}
imageDTO={controlImage}
isDropDisabled={shouldShowProcessedImage} isDropDisabled={shouldShowProcessedImage}
postUploadAction={{ type: 'SET_CONTROLNET_IMAGE', controlNetId }} onClickReset={handleResetControlImage}
imageSx={{ postUploadAction={postUploadAction}
w: 'full', resetTooltip="Reset Control Image"
h: 'full', withResetIcon={Boolean(controlImage)}
}}
/> />
<Box <Box
sx={{ sx={{
@ -127,14 +137,13 @@ const ControlNetImagePreview = (props: Props) => {
}} }}
> >
<IAIDndImage <IAIDndImage
image={processedControlImage} draggableData={draggableData}
onDrop={handleDrop} droppableData={droppableData}
payloadImage={controlImage} imageDTO={processedControlImage}
isUploadDisabled={true} isUploadDisabled={true}
imageSx={{ onClickReset={handleResetControlImage}
w: 'full', resetTooltip="Reset Control Image"
h: 'full', withResetIcon={Boolean(controlImage)}
}}
/> />
</Box> </Box>
{pendingControlImages.includes(controlNetId) && ( {pendingControlImages.includes(controlNetId) && (
@ -145,27 +154,12 @@ const ControlNetImagePreview = (props: Props) => {
insetInlineStart: 0, insetInlineStart: 0,
w: 'full', w: 'full',
h: 'full', h: 'full',
objectFit: 'contain',
}} }}
> >
<IAIImageLoadingFallback /> <IAILoadingImageFallback image={controlImage} />
</Box> </Box>
)} )}
{controlImage && (
<Flex sx={{ position: 'absolute', top: 0, insetInlineEnd: 0 }}>
<IAIIconButton
aria-label="Reset Control Image"
tooltip="Reset Control Image"
size="sm"
onClick={handleResetControlImage}
icon={<FaUndo />}
variant="link"
sx={{
p: 2,
color: 'base.50',
}}
/>
</Flex>
)}
</Flex> </Flex>
); );
}; };

View File

@ -0,0 +1,36 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISwitch from 'common/components/IAISwitch';
import { isControlNetEnabledToggled } from 'features/controlNet/store/controlNetSlice';
import { useCallback } from 'react';
const selector = createSelector(
stateSelector,
(state) => {
const { isEnabled } = state.controlNet;
return { isEnabled };
},
defaultSelectorOptions
);
const ParamControlNetFeatureToggle = () => {
const { isEnabled } = useAppSelector(selector);
const dispatch = useAppDispatch();
const handleChange = useCallback(() => {
dispatch(isControlNetEnabledToggled());
}, [dispatch]);
return (
<IAISwitch
label="Enable ControlNet"
isChecked={isEnabled}
onChange={handleChange}
/>
);
};
export default ParamControlNetFeatureToggle;

View File

@ -0,0 +1,15 @@
import { filter } from 'lodash-es';
import { ControlNetConfig } from '../store/controlNetSlice';
export const getValidControlNets = (
controlNets: Record<string, ControlNetConfig>
) => {
const validControlNets = filter(
controlNets,
(c) =>
c.isEnabled &&
(Boolean(c.processedControlImage) ||
(c.processorType === 'none' && Boolean(c.controlImage)))
);
return validControlNets;
};

View File

@ -1,40 +1,30 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse'; import IAICollapse from 'common/components/IAICollapse';
import { useCallback } from 'react';
import { isEnabledToggled } from '../store/slice';
import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts';
import ParamDynamicPromptsCombinatorial from './ParamDynamicPromptsCombinatorial'; import ParamDynamicPromptsCombinatorial from './ParamDynamicPromptsCombinatorial';
import { Flex } from '@chakra-ui/react'; import ParamDynamicPromptsToggle from './ParamDynamicPromptsEnabled';
import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts';
const selector = createSelector( const selector = createSelector(
stateSelector, stateSelector,
(state) => { (state) => {
const { isEnabled } = state.dynamicPrompts; const { isEnabled } = state.dynamicPrompts;
return { isEnabled }; return { activeLabel: isEnabled ? 'Enabled' : undefined };
}, },
defaultSelectorOptions defaultSelectorOptions
); );
const ParamDynamicPromptsCollapse = () => { const ParamDynamicPromptsCollapse = () => {
const dispatch = useAppDispatch(); const { activeLabel } = useAppSelector(selector);
const { isEnabled } = useAppSelector(selector);
const handleToggleIsEnabled = useCallback(() => {
dispatch(isEnabledToggled());
}, [dispatch]);
return ( return (
<IAICollapse <IAICollapse label="Dynamic Prompts" activeLabel={activeLabel}>
isOpen={isEnabled}
onToggle={handleToggleIsEnabled}
label="Dynamic Prompts"
withSwitch
>
<Flex sx={{ gap: 2, flexDir: 'column' }}> <Flex sx={{ gap: 2, flexDir: 'column' }}>
<ParamDynamicPromptsToggle />
<ParamDynamicPromptsCombinatorial /> <ParamDynamicPromptsCombinatorial />
<ParamDynamicPromptsMaxPrompts /> <ParamDynamicPromptsMaxPrompts />
</Flex> </Flex>

View File

@ -1,23 +1,23 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { combinatorialToggled } from '../store/slice';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useCallback } from 'react';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { useCallback } from 'react';
import { combinatorialToggled } from '../store/slice';
const selector = createSelector( const selector = createSelector(
stateSelector, stateSelector,
(state) => { (state) => {
const { combinatorial } = state.dynamicPrompts; const { combinatorial, isEnabled } = state.dynamicPrompts;
return { combinatorial }; return { combinatorial, isDisabled: !isEnabled };
}, },
defaultSelectorOptions defaultSelectorOptions
); );
const ParamDynamicPromptsCombinatorial = () => { const ParamDynamicPromptsCombinatorial = () => {
const { combinatorial } = useAppSelector(selector); const { combinatorial, isDisabled } = useAppSelector(selector);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const handleChange = useCallback(() => { const handleChange = useCallback(() => {
@ -26,6 +26,7 @@ const ParamDynamicPromptsCombinatorial = () => {
return ( return (
<IAISwitch <IAISwitch
isDisabled={isDisabled}
label="Combinatorial Generation" label="Combinatorial Generation"
isChecked={combinatorial} isChecked={combinatorial}
onChange={handleChange} onChange={handleChange}

View File

@ -0,0 +1,36 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISwitch from 'common/components/IAISwitch';
import { useCallback } from 'react';
import { isEnabledToggled } from '../store/slice';
const selector = createSelector(
stateSelector,
(state) => {
const { isEnabled } = state.dynamicPrompts;
return { isEnabled };
},
defaultSelectorOptions
);
const ParamDynamicPromptsToggle = () => {
const dispatch = useAppDispatch();
const { isEnabled } = useAppSelector(selector);
const handleToggleIsEnabled = useCallback(() => {
dispatch(isEnabledToggled());
}, [dispatch]);
return (
<IAISwitch
label="Enable Dynamic Prompts"
isChecked={isEnabled}
onChange={handleToggleIsEnabled}
/>
);
};
export default ParamDynamicPromptsToggle;

View File

@ -1,25 +1,31 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { maxPromptsChanged, maxPromptsReset } from '../store/slice';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useCallback } from 'react';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider';
import { useCallback } from 'react';
import { maxPromptsChanged, maxPromptsReset } from '../store/slice';
const selector = createSelector( const selector = createSelector(
stateSelector, stateSelector,
(state) => { (state) => {
const { maxPrompts, combinatorial } = state.dynamicPrompts; const { maxPrompts, combinatorial, isEnabled } = state.dynamicPrompts;
const { min, sliderMax, inputMax } = const { min, sliderMax, inputMax } =
state.config.sd.dynamicPrompts.maxPrompts; state.config.sd.dynamicPrompts.maxPrompts;
return { maxPrompts, min, sliderMax, inputMax, combinatorial }; return {
maxPrompts,
min,
sliderMax,
inputMax,
isDisabled: !isEnabled || !combinatorial,
};
}, },
defaultSelectorOptions defaultSelectorOptions
); );
const ParamDynamicPromptsMaxPrompts = () => { const ParamDynamicPromptsMaxPrompts = () => {
const { maxPrompts, min, sliderMax, inputMax, combinatorial } = const { maxPrompts, min, sliderMax, inputMax, isDisabled } =
useAppSelector(selector); useAppSelector(selector);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
@ -37,7 +43,7 @@ const ParamDynamicPromptsMaxPrompts = () => {
return ( return (
<IAISlider <IAISlider
label="Max Prompts" label="Max Prompts"
isDisabled={!combinatorial} isDisabled={isDisabled}
min={min} min={min}
max={sliderMax} max={sliderMax}
value={maxPrompts} value={maxPrompts}

View File

@ -0,0 +1,33 @@
import IAIIconButton from 'common/components/IAIIconButton';
import { memo } from 'react';
import { BiCode } from 'react-icons/bi';
type Props = {
onClick: () => void;
};
const AddEmbeddingButton = (props: Props) => {
const { onClick } = props;
return (
<IAIIconButton
size="sm"
aria-label="Add Embedding"
tooltip="Add Embedding"
icon={<BiCode />}
sx={{
p: 2,
color: 'base.700',
_hover: {
color: 'base.550',
},
_active: {
color: 'base.500',
},
}}
variant="link"
onClick={onClick}
/>
);
};
export default memo(AddEmbeddingButton);

View File

@ -0,0 +1,151 @@
import {
Flex,
Popover,
PopoverBody,
PopoverContent,
PopoverTrigger,
Text,
} from '@chakra-ui/react';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
import { forEach } from 'lodash-es';
import {
PropsWithChildren,
forwardRef,
useCallback,
useMemo,
useRef,
} from 'react';
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
type EmbeddingSelectItem = {
label: string;
value: string;
description?: string;
};
type Props = PropsWithChildren & {
onSelect: (v: string) => void;
isOpen: boolean;
onClose: () => void;
};
const ParamEmbeddingPopover = (props: Props) => {
const { onSelect, isOpen, onClose, children } = props;
const { data: embeddingQueryData } = useGetTextualInversionModelsQuery();
const inputRef = useRef<HTMLInputElement>(null);
const data = useMemo(() => {
if (!embeddingQueryData) {
return [];
}
const data: EmbeddingSelectItem[] = [];
forEach(embeddingQueryData.entities, (embedding, _) => {
if (!embedding) return;
data.push({
value: embedding.name,
label: embedding.name,
description: embedding.description,
});
});
return data;
}, [embeddingQueryData]);
const handleChange = useCallback(
(v: string[]) => {
if (v.length === 0) {
return;
}
onSelect(v[0]);
},
[onSelect]
);
return (
<Popover
initialFocusRef={inputRef}
isOpen={isOpen}
onClose={onClose}
placement="bottom"
openDelay={0}
closeDelay={0}
closeOnBlur={true}
returnFocusOnClose={true}
>
<PopoverTrigger>{children}</PopoverTrigger>
<PopoverContent
sx={{
p: 0,
top: -1,
shadow: 'dark-lg',
borderColor: 'accent.300',
borderWidth: '2px',
borderStyle: 'solid',
_dark: { borderColor: 'accent.400' },
}}
>
<PopoverBody
sx={{ p: 0, w: `calc(${PARAMETERS_PANEL_WIDTH} - 2rem )` }}
>
{data.length === 0 ? (
<Flex sx={{ justifyContent: 'center', p: 2 }}>
<Text
sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}
>
No Embeddings Loaded
</Text>
</Flex>
) : (
<IAIMantineMultiSelect
inputRef={inputRef}
placeholder={'Add Embedding'}
value={[]}
data={data}
maxDropdownHeight={400}
nothingFound="No Matching Embeddings"
itemComponent={SelectItem}
disabled={data.length === 0}
filter={(value, selected, item: EmbeddingSelectItem) =>
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim())
}
onChange={handleChange}
/>
)}
</PopoverBody>
</PopoverContent>
</Popover>
);
};
export default ParamEmbeddingPopover;
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
value: string;
label: string;
description?: string;
}
const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
({ label, description, ...others }: ItemProps, ref) => {
return (
<div ref={ref} {...others}>
<div>
<Text>{label}</Text>
{description && (
<Text size="xs" color="base.600">
{description}
</Text>
)}
</div>
</div>
);
}
);
SelectItem.displayName = 'SelectItem';

View File

@ -1,16 +1,16 @@
import { Flex, Text, useColorMode } from '@chakra-ui/react'; import { Flex, useColorMode } from '@chakra-ui/react';
import { FaImages } from 'react-icons/fa'; import { FaImages } from 'react-icons/fa';
import { boardIdSelected } from '../../store/boardSlice'; import { boardIdSelected } from 'features/gallery/store/gallerySlice';
import { useDispatch } from 'react-redux'; import { useDispatch } from 'react-redux';
import { IAINoImageFallback } from 'common/components/IAIImageFallback'; import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { AnimatePresence } from 'framer-motion'; import { AnimatePresence } from 'framer-motion';
import { SelectedItemOverlay } from '../SelectedItemOverlay';
import { useCallback } from 'react';
import { ImageDTO } from 'services/api/types';
import { useRemoveImageFromBoardMutation } from 'services/api/endpoints/boardImages';
import { useDroppable } from '@dnd-kit/core';
import IAIDropOverlay from 'common/components/IAIDropOverlay'; import IAIDropOverlay from 'common/components/IAIDropOverlay';
import { mode } from 'theme/util/mode'; import { mode } from 'theme/util/mode';
import {
MoveBoardDropData,
isValidDrop,
useDroppable,
} from 'app/components/ImageDnd/typesafeDnd';
const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => { const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => {
const dispatch = useDispatch(); const dispatch = useDispatch();
@ -20,31 +20,15 @@ const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => {
dispatch(boardIdSelected()); dispatch(boardIdSelected());
}; };
const [removeImageFromBoard, { isLoading }] = const droppableData: MoveBoardDropData = {
useRemoveImageFromBoardMutation(); id: 'all-images-board',
actionType: 'MOVE_BOARD',
context: { boardId: null },
};
const handleDrop = useCallback( const { isOver, setNodeRef, active } = useDroppable({
(droppedImage: ImageDTO) => {
if (!droppedImage.board_id) {
return;
}
removeImageFromBoard({
board_id: droppedImage.board_id,
image_name: droppedImage.image_name,
});
},
[removeImageFromBoard]
);
const {
isOver,
setNodeRef,
active: isDropActive,
} = useDroppable({
id: `board_droppable_all_images`, id: `board_droppable_all_images`,
data: { data: droppableData,
handleDrop,
},
}); });
return ( return (
@ -58,10 +42,10 @@ const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => {
h: 'full', h: 'full',
borderRadius: 'base', borderRadius: 'base',
}} }}
onClick={handleAllImagesBoardClick}
> >
<Flex <Flex
ref={setNodeRef} ref={setNodeRef}
onClick={handleAllImagesBoardClick}
sx={{ sx={{
position: 'relative', position: 'relative',
justifyContent: 'center', justifyContent: 'center',
@ -69,18 +53,30 @@ const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => {
borderRadius: 'base', borderRadius: 'base',
w: 'full', w: 'full',
aspectRatio: '1/1', aspectRatio: '1/1',
overflow: 'hidden',
shadow: isSelected ? 'selected.light' : undefined,
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
flexShrink: 0,
}} }}
> >
<IAINoImageFallback iconProps={{ boxSize: 8 }} as={FaImages} /> <IAINoContentFallback
boxSize={8}
icon={FaImages}
sx={{
border: '2px solid var(--invokeai-colors-base-200)',
_dark: { border: '2px solid var(--invokeai-colors-base-800)' },
}}
/>
<AnimatePresence> <AnimatePresence>
{isSelected && <SelectedItemOverlay />} {isValidDrop(droppableData, active) && (
</AnimatePresence> <IAIDropOverlay isOver={isOver} />
<AnimatePresence> )}
{isDropActive && <IAIDropOverlay isOver={isOver} />}
</AnimatePresence> </AnimatePresence>
</Flex> </Flex>
<Text <Flex
sx={{ sx={{
h: 'full',
alignItems: 'center',
color: isSelected color: isSelected
? mode('base.900', 'base.50')(colorMode) ? mode('base.900', 'base.50')(colorMode)
: mode('base.700', 'base.200')(colorMode), : mode('base.700', 'base.200')(colorMode),
@ -89,7 +85,7 @@ const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => {
}} }}
> >
All Images All Images
</Text> </Flex>
</Flex> </Flex>
); );
}; };

View File

@ -2,6 +2,7 @@ import {
Collapse, Collapse,
Flex, Flex,
Grid, Grid,
GridItem,
IconButton, IconButton,
Input, Input,
InputGroup, InputGroup,
@ -10,10 +11,7 @@ import {
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { import { setBoardSearchText } from 'features/gallery/store/boardSlice';
boardsSelector,
setBoardSearchText,
} from 'features/gallery/store/boardSlice';
import { memo, useState } from 'react'; import { memo, useState } from 'react';
import HoverableBoard from './HoverableBoard'; import HoverableBoard from './HoverableBoard';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
@ -21,11 +19,13 @@ import AddBoardButton from './AddBoardButton';
import AllImagesBoard from './AllImagesBoard'; import AllImagesBoard from './AllImagesBoard';
import { CloseIcon } from '@chakra-ui/icons'; import { CloseIcon } from '@chakra-ui/icons';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards'; import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
import { stateSelector } from 'app/store/store';
const selector = createSelector( const selector = createSelector(
[boardsSelector], [stateSelector],
(boardsState) => { ({ boards, gallery }) => {
const { selectedBoardId, searchText } = boardsState; const { searchText } = boards;
const { selectedBoardId } = gallery;
return { selectedBoardId, searchText }; return { selectedBoardId, searchText };
}, },
defaultSelectorOptions defaultSelectorOptions
@ -109,20 +109,24 @@ const BoardsList = (props: Props) => {
<Grid <Grid
className="list-container" className="list-container"
sx={{ sx={{
gap: 2, gridTemplateRows: '6.5rem 6.5rem',
gridTemplateRows: '5.5rem 5.5rem',
gridAutoFlow: 'column dense', gridAutoFlow: 'column dense',
gridAutoColumns: '4rem', gridAutoColumns: '5rem',
}} }}
> >
{!searchMode && <AllImagesBoard isSelected={!selectedBoardId} />} {!searchMode && (
<GridItem sx={{ p: 1.5 }}>
<AllImagesBoard isSelected={!selectedBoardId} />
</GridItem>
)}
{filteredBoards && {filteredBoards &&
filteredBoards.map((board) => ( filteredBoards.map((board) => (
<GridItem key={board.board_id} sx={{ p: 1.5 }}>
<HoverableBoard <HoverableBoard
key={board.board_id}
board={board} board={board}
isSelected={selectedBoardId === board.board_id} isSelected={selectedBoardId === board.board_id}
/> />
</GridItem>
))} ))}
</Grid> </Grid>
</OverlayScrollbarsComponent> </OverlayScrollbarsComponent>

View File

@ -15,10 +15,9 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { memo, useCallback, useContext } from 'react'; import { memo, useCallback, useContext } from 'react';
import { FaFolder, FaTrash } from 'react-icons/fa'; import { FaFolder, FaTrash } from 'react-icons/fa';
import { ContextMenu } from 'chakra-ui-contextmenu'; import { ContextMenu } from 'chakra-ui-contextmenu';
import { BoardDTO, ImageDTO } from 'services/api/types'; import { BoardDTO } from 'services/api/types';
import { IAINoImageFallback } from 'common/components/IAIImageFallback'; import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { boardIdSelected } from 'features/gallery/store/boardSlice'; import { boardIdSelected } from 'features/gallery/store/gallerySlice';
import { useAddImageToBoardMutation } from 'services/api/endpoints/boardImages';
import { import {
useDeleteBoardMutation, useDeleteBoardMutation,
useUpdateBoardMutation, useUpdateBoardMutation,
@ -26,12 +25,15 @@ import {
import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { skipToken } from '@reduxjs/toolkit/dist/query'; import { skipToken } from '@reduxjs/toolkit/dist/query';
import { useDroppable } from '@dnd-kit/core';
import { AnimatePresence } from 'framer-motion'; import { AnimatePresence } from 'framer-motion';
import IAIDropOverlay from 'common/components/IAIDropOverlay'; import IAIDropOverlay from 'common/components/IAIDropOverlay';
import { SelectedItemOverlay } from '../SelectedItemOverlay';
import { DeleteBoardImagesContext } from '../../../../app/contexts/DeleteBoardImagesContext'; import { DeleteBoardImagesContext } from '../../../../app/contexts/DeleteBoardImagesContext';
import { mode } from 'theme/util/mode'; import { mode } from 'theme/util/mode';
import {
MoveBoardDropData,
isValidDrop,
useDroppable,
} from 'app/components/ImageDnd/typesafeDnd';
interface HoverableBoardProps { interface HoverableBoardProps {
board: BoardDTO; board: BoardDTO;
@ -61,9 +63,6 @@ const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => {
const [deleteBoard, { isLoading: isDeleteBoardLoading }] = const [deleteBoard, { isLoading: isDeleteBoardLoading }] =
useDeleteBoardMutation(); useDeleteBoardMutation();
const [addImageToBoard, { isLoading: isAddImageToBoardLoading }] =
useAddImageToBoardMutation();
const handleUpdateBoardName = (newBoardName: string) => { const handleUpdateBoardName = (newBoardName: string) => {
updateBoard({ board_id, changes: { board_name: newBoardName } }); updateBoard({ board_id, changes: { board_name: newBoardName } });
}; };
@ -77,29 +76,19 @@ const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => {
onClickDeleteBoardImages(board); onClickDeleteBoardImages(board);
}, [board, onClickDeleteBoardImages]); }, [board, onClickDeleteBoardImages]);
const handleDrop = useCallback( const droppableData: MoveBoardDropData = {
(droppedImage: ImageDTO) => { id: board_id,
if (droppedImage.board_id === board_id) { actionType: 'MOVE_BOARD',
return; context: { boardId: board_id },
} };
addImageToBoard({ board_id, image_name: droppedImage.image_name });
},
[addImageToBoard, board_id]
);
const { const { isOver, setNodeRef, active } = useDroppable({
isOver,
setNodeRef,
active: isDropActive,
} = useDroppable({
id: `board_droppable_${board_id}`, id: `board_droppable_${board_id}`,
data: { data: droppableData,
handleDrop,
},
}); });
return ( return (
<Box sx={{ touchAction: 'none' }}> <Box sx={{ touchAction: 'none', height: 'full' }}>
<ContextMenu<HTMLDivElement> <ContextMenu<HTMLDivElement>
menuProps={{ size: 'sm', isLazy: true }} menuProps={{ size: 'sm', isLazy: true }}
renderMenu={() => ( renderMenu={() => (
@ -148,13 +137,25 @@ const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => {
w: 'full', w: 'full',
aspectRatio: '1/1', aspectRatio: '1/1',
overflow: 'hidden', overflow: 'hidden',
shadow: isSelected ? 'selected.light' : undefined,
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
flexShrink: 0,
}} }}
> >
{board.cover_image_name && coverImage?.image_url && ( {board.cover_image_name && coverImage?.image_url && (
<Image src={coverImage?.image_url} draggable={false} /> <Image src={coverImage?.image_url} draggable={false} />
)} )}
{!(board.cover_image_name && coverImage?.image_url) && ( {!(board.cover_image_name && coverImage?.image_url) && (
<IAINoImageFallback iconProps={{ boxSize: 8 }} as={FaFolder} /> <IAINoContentFallback
boxSize={8}
icon={FaFolder}
sx={{
border: '2px solid var(--invokeai-colors-base-200)',
_dark: {
border: '2px solid var(--invokeai-colors-base-800)',
},
}}
/>
)} )}
<Flex <Flex
sx={{ sx={{
@ -167,14 +168,20 @@ const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => {
<Badge variant="solid">{board.image_count}</Badge> <Badge variant="solid">{board.image_count}</Badge>
</Flex> </Flex>
<AnimatePresence> <AnimatePresence>
{isSelected && <SelectedItemOverlay />} {isValidDrop(droppableData, active) && (
</AnimatePresence> <IAIDropOverlay isOver={isOver} />
<AnimatePresence> )}
{isDropActive && <IAIDropOverlay isOver={isOver} />}
</AnimatePresence> </AnimatePresence>
</Flex> </Flex>
<Box sx={{ width: 'full' }}> <Flex
sx={{
width: 'full',
height: 'full',
justifyContent: 'center',
alignItems: 'center',
}}
>
<Editable <Editable
defaultValue={board_name} defaultValue={board_name}
submitOnBlur={false} submitOnBlur={false}
@ -204,7 +211,7 @@ const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => {
}} }}
/> />
</Editable> </Editable>
</Box> </Flex>
</Flex> </Flex>
)} )}
</ContextMenu> </ContextMenu>

View File

@ -38,8 +38,7 @@ import {
FaShare, FaShare,
FaShareAlt, FaShareAlt,
} from 'react-icons/fa'; } from 'react-icons/fa';
import { gallerySelector } from '../store/gallerySelectors'; import { useCallback } from 'react';
import { useCallback, useContext } from 'react';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
@ -49,22 +48,15 @@ import FaceRestoreSettings from 'features/parameters/components/Parameters/FaceR
import UpscaleSettings from 'features/parameters/components/Parameters/Upscale/UpscaleSettings'; import UpscaleSettings from 'features/parameters/components/Parameters/Upscale/UpscaleSettings';
import { useAppToaster } from 'app/components/Toaster'; import { useAppToaster } from 'app/components/Toaster';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { DeleteImageContext } from 'app/contexts/DeleteImageContext'; import { stateSelector } from 'app/store/store';
import { DeleteImageButton } from './DeleteImageModal'; import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { selectImagesById } from '../store/imagesSlice'; import { skipToken } from '@reduxjs/toolkit/dist/query';
import { RootState } from 'app/store/store'; import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice';
import { DeleteImageButton } from 'features/imageDeletion/components/DeleteImageButton';
const currentImageButtonsSelector = createSelector( const currentImageButtonsSelector = createSelector(
[ [stateSelector, activeTabNameSelector],
(state: RootState) => state, ({ gallery, system, postprocessing, ui, lightbox }, activeTabName) => {
systemSelector,
gallerySelector,
postprocessingSelector,
uiSelector,
lightboxSelector,
activeTabNameSelector,
],
(state, system, gallery, postprocessing, ui, lightbox, activeTabName) => {
const { const {
isProcessing, isProcessing,
isConnected, isConnected,
@ -84,9 +76,7 @@ const currentImageButtonsSelector = createSelector(
shouldShowProgressInViewer, shouldShowProgressInViewer,
} = ui; } = ui;
const imageDTO = selectImagesById(state, gallery.selectedImage ?? ''); const lastSelectedImage = gallery.selection[gallery.selection.length - 1];
const { selectedImage } = gallery;
return { return {
canDeleteImage: isConnected && !isProcessing, canDeleteImage: isConnected && !isProcessing,
@ -97,16 +87,13 @@ const currentImageButtonsSelector = createSelector(
isESRGANAvailable, isESRGANAvailable,
upscalingLevel, upscalingLevel,
facetoolStrength, facetoolStrength,
shouldDisableToolbarButtons: Boolean(progressImage) || !selectedImage, shouldDisableToolbarButtons: Boolean(progressImage) || !lastSelectedImage,
shouldShowImageDetails, shouldShowImageDetails,
activeTabName, activeTabName,
isLightboxOpen, isLightboxOpen,
shouldHidePreview, shouldHidePreview,
image: imageDTO,
seed: imageDTO?.metadata?.seed,
prompt: imageDTO?.metadata?.positive_conditioning,
negativePrompt: imageDTO?.metadata?.negative_conditioning,
shouldShowProgressInViewer, shouldShowProgressInViewer,
lastSelectedImage,
}; };
}, },
{ {
@ -132,7 +119,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
isLightboxOpen, isLightboxOpen,
activeTabName, activeTabName,
shouldHidePreview, shouldHidePreview,
image, lastSelectedImage,
shouldShowProgressInViewer, shouldShowProgressInViewer,
} = useAppSelector(currentImageButtonsSelector); } = useAppSelector(currentImageButtonsSelector);
@ -147,7 +134,9 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const { recallBothPrompts, recallSeed, recallAllParameters } = const { recallBothPrompts, recallSeed, recallAllParameters } =
useRecallParameters(); useRecallParameters();
const { onDelete } = useContext(DeleteImageContext); const { currentData: image } = useGetImageDTOQuery(
lastSelectedImage ?? skipToken
);
// const handleCopyImage = useCallback(async () => { // const handleCopyImage = useCallback(async () => {
// if (!image?.url) { // if (!image?.url) {
@ -248,8 +237,11 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}, []); }, []);
const handleDelete = useCallback(() => { const handleDelete = useCallback(() => {
onDelete(image); if (!image) {
}, [image, onDelete]); return;
}
dispatch(imageToDeleteSelected(image));
}, [dispatch, image]);
useHotkeys( useHotkeys(
'Shift+U', 'Shift+U',
@ -371,7 +363,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}} }}
{...props} {...props}
> >
<ButtonGroup isAttached={true}> <ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
<IAIPopover <IAIPopover
triggerComponent={ triggerComponent={
<IAIIconButton <IAIIconButton
@ -444,11 +436,12 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
} }
isChecked={isLightboxOpen} isChecked={isLightboxOpen}
onClick={handleLightBox} onClick={handleLightBox}
isDisabled={shouldDisableToolbarButtons}
/> />
)} )}
</ButtonGroup> </ButtonGroup>
<ButtonGroup isAttached={true}> <ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
<IAIIconButton <IAIIconButton
icon={<FaQuoteRight />} icon={<FaQuoteRight />}
tooltip={`${t('parameters.usePrompt')} (P)`} tooltip={`${t('parameters.usePrompt')} (P)`}
@ -478,7 +471,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
</ButtonGroup> </ButtonGroup>
{(isUpscalingEnabled || isFaceRestoreEnabled) && ( {(isUpscalingEnabled || isFaceRestoreEnabled) && (
<ButtonGroup isAttached={true}> <ButtonGroup
isAttached={true}
isDisabled={shouldDisableToolbarButtons}
>
{isFaceRestoreEnabled && ( {isFaceRestoreEnabled && (
<IAIPopover <IAIPopover
triggerComponent={ triggerComponent={
@ -543,7 +539,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
</ButtonGroup> </ButtonGroup>
)} )}
<ButtonGroup isAttached={true}> <ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
<IAIIconButton <IAIIconButton
icon={<FaCode />} icon={<FaCode />}
tooltip={`${t('parameters.info')} (I)`} tooltip={`${t('parameters.info')} (I)`}
@ -553,7 +549,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
/> />
</ButtonGroup> </ButtonGroup>
<ButtonGroup isAttached={true}> <ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
<IAIIconButton <IAIIconButton
aria-label={t('settings.displayInProgress')} aria-label={t('settings.displayInProgress')}
tooltip={t('settings.displayInProgress')} tooltip={t('settings.displayInProgress')}
@ -564,7 +560,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
</ButtonGroup> </ButtonGroup>
<ButtonGroup isAttached={true}> <ButtonGroup isAttached={true}>
<DeleteImageButton onClick={handleDelete} /> <DeleteImageButton
onClick={handleDelete}
isDisabled={shouldDisableToolbarButtons}
/>
</ButtonGroup> </ButtonGroup>
</Flex> </Flex>
</> </>

View File

@ -1,29 +1,9 @@
import { Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { systemSelector } from 'features/system/store/systemSelectors';
import { gallerySelector } from '../store/gallerySelectors';
import CurrentImageButtons from './CurrentImageButtons'; import CurrentImageButtons from './CurrentImageButtons';
import CurrentImagePreview from './CurrentImagePreview'; import CurrentImagePreview from './CurrentImagePreview';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
export const currentImageDisplaySelector = createSelector(
[systemSelector, gallerySelector],
(system, gallery) => {
const { progressImage } = system;
return {
hasSelectedImage: Boolean(gallery.selectedImage),
hasProgressImage: Boolean(progressImage),
};
},
defaultSelectorOptions
);
const CurrentImageDisplay = () => { const CurrentImageDisplay = () => {
const { hasSelectedImage } = useAppSelector(currentImageDisplaySelector);
return ( return (
<Flex <Flex
sx={{ sx={{
@ -36,7 +16,7 @@ const CurrentImageDisplay = () => {
justifyContent: 'center', justifyContent: 'center',
}} }}
> >
{hasSelectedImage && <CurrentImageButtons />} <CurrentImageButtons />
<CurrentImagePreview /> <CurrentImagePreview />
</Flex> </Flex>
); );

View File

@ -1,35 +1,33 @@
import { Box, Flex, Image } from '@chakra-ui/react'; import { Box, Flex, Image } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { skipToken } from '@reduxjs/toolkit/dist/query';
import { uiSelector } from 'features/ui/store/uiSelectors'; import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySlice';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
import { memo, useMemo } from 'react';
import { gallerySelector } from '../store/gallerySelectors'; import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer'; import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
import NextPrevImageButtons from './NextPrevImageButtons'; import NextPrevImageButtons from './NextPrevImageButtons';
import { memo, useCallback } from 'react';
import { systemSelector } from 'features/system/store/systemSelectors';
import { imageSelected } from '../store/gallerySlice';
import IAIDndImage from 'common/components/IAIDndImage';
import { ImageDTO } from 'services/api/types';
import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { skipToken } from '@reduxjs/toolkit/dist/query';
export const imagesSelector = createSelector( export const imagesSelector = createSelector(
[uiSelector, gallerySelector, systemSelector], [stateSelector, selectLastSelectedImage],
(ui, gallery, system) => { ({ ui, system }, lastSelectedImage) => {
const { const {
shouldShowImageDetails, shouldShowImageDetails,
shouldHidePreview, shouldHidePreview,
shouldShowProgressInViewer, shouldShowProgressInViewer,
} = ui; } = ui;
const { selectedImage } = gallery;
const { progressImage, shouldAntialiasProgressImage } = system; const { progressImage, shouldAntialiasProgressImage } = system;
return { return {
shouldShowImageDetails, shouldShowImageDetails,
shouldHidePreview, shouldHidePreview,
selectedImage, imageName: lastSelectedImage,
progressImage, progressImage,
shouldShowProgressInViewer, shouldShowProgressInViewer,
shouldAntialiasProgressImage, shouldAntialiasProgressImage,
@ -45,29 +43,35 @@ export const imagesSelector = createSelector(
const CurrentImagePreview = () => { const CurrentImagePreview = () => {
const { const {
shouldShowImageDetails, shouldShowImageDetails,
selectedImage, imageName,
progressImage, progressImage,
shouldShowProgressInViewer, shouldShowProgressInViewer,
shouldAntialiasProgressImage, shouldAntialiasProgressImage,
} = useAppSelector(imagesSelector); } = useAppSelector(imagesSelector);
const { const {
currentData: image, currentData: imageDTO,
isLoading, isLoading,
isError, isError,
isSuccess, isSuccess,
} = useGetImageDTOQuery(selectedImage ?? skipToken); } = useGetImageDTOQuery(imageName ?? skipToken);
const dispatch = useAppDispatch(); const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
if (imageDTO) {
const handleDrop = useCallback( return {
(droppedImage: ImageDTO) => { id: 'current-image',
if (droppedImage.image_name === image?.image_name) { payloadType: 'IMAGE_DTO',
return; payload: { imageDTO },
};
} }
dispatch(imageSelected(droppedImage.image_name)); }, [imageDTO]);
},
[dispatch, image?.image_name] const droppableData = useMemo<TypesafeDroppableData | undefined>(
() => ({
id: 'current-image',
actionType: 'SET_CURRENT_IMAGE',
}),
[]
); );
return ( return (
@ -98,14 +102,15 @@ const CurrentImagePreview = () => {
/> />
) : ( ) : (
<IAIDndImage <IAIDndImage
image={image} imageDTO={imageDTO}
onDrop={handleDrop} droppableData={droppableData}
fallback={<IAIImageLoadingFallback sx={{ bg: 'none' }} />} draggableData={draggableData}
isUploadDisabled={true} isUploadDisabled={true}
fitContainer fitContainer
dropLabel="Set as Current Image"
/> />
)} )}
{shouldShowImageDetails && image && ( {shouldShowImageDetails && imageDTO && (
<Box <Box
sx={{ sx={{
position: 'absolute', position: 'absolute',
@ -116,10 +121,10 @@ const CurrentImagePreview = () => {
overflow: 'scroll', overflow: 'scroll',
}} }}
> >
<ImageMetadataViewer image={image} /> <ImageMetadataViewer image={imageDTO} />
</Box> </Box>
)} )}
{!shouldShowImageDetails && image && ( {!shouldShowImageDetails && imageDTO && (
<Box <Box
sx={{ sx={{
position: 'absolute', position: 'absolute',

View File

@ -1,166 +0,0 @@
import {
AlertDialog,
AlertDialogBody,
AlertDialogContent,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogOverlay,
Divider,
Flex,
ListItem,
Text,
UnorderedList,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import {
DeleteImageContext,
ImageUsage,
} from 'app/contexts/DeleteImageContext';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch';
import { configSelector } from 'features/system/store/configSelectors';
import { systemSelector } from 'features/system/store/systemSelectors';
import { setShouldConfirmOnDelete } from 'features/system/store/systemSlice';
import { some } from 'lodash-es';
import { ChangeEvent, memo, useCallback, useContext, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { FaTrash } from 'react-icons/fa';
const selector = createSelector(
[systemSelector, configSelector],
(system, config) => {
const { shouldConfirmOnDelete } = system;
const { canRestoreDeletedImagesFromBin } = config;
return {
shouldConfirmOnDelete,
canRestoreDeletedImagesFromBin,
};
},
defaultSelectorOptions
);
const ImageInUseMessage = (props: { imageUsage?: ImageUsage }) => {
const { imageUsage } = props;
if (!imageUsage) {
return null;
}
if (!some(imageUsage)) {
return null;
}
return (
<>
<Text>This image is currently in use in the following features:</Text>
<UnorderedList sx={{ paddingInlineStart: 6 }}>
{imageUsage.isInitialImage && <ListItem>Image to Image</ListItem>}
{imageUsage.isCanvasImage && <ListItem>Unified Canvas</ListItem>}
{imageUsage.isControlNetImage && <ListItem>ControlNet</ListItem>}
{imageUsage.isNodesImage && <ListItem>Node Editor</ListItem>}
</UnorderedList>
<Text>
If you delete this image, those features will immediately be reset.
</Text>
</>
);
};
const DeleteImageModal = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { isOpen, onClose, onImmediatelyDelete, image, imageUsage } =
useContext(DeleteImageContext);
const { shouldConfirmOnDelete, canRestoreDeletedImagesFromBin } =
useAppSelector(selector);
const handleChangeShouldConfirmOnDelete = useCallback(
(e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldConfirmOnDelete(!e.target.checked)),
[dispatch]
);
const cancelRef = useRef<HTMLButtonElement>(null);
return (
<AlertDialog
isOpen={isOpen}
leastDestructiveRef={cancelRef}
onClose={onClose}
isCentered
>
<AlertDialogOverlay>
<AlertDialogContent>
<AlertDialogHeader fontSize="lg" fontWeight="bold">
{t('gallery.deleteImage')}
</AlertDialogHeader>
<AlertDialogBody>
<Flex direction="column" gap={3}>
<ImageInUseMessage imageUsage={imageUsage} />
<Divider />
<Text>
{canRestoreDeletedImagesFromBin
? t('gallery.deleteImageBin')
: t('gallery.deleteImagePermanent')}
</Text>
<Text>{t('common.areYouSure')}</Text>
<IAISwitch
label={t('common.dontAskMeAgain')}
isChecked={!shouldConfirmOnDelete}
onChange={handleChangeShouldConfirmOnDelete}
/>
</Flex>
</AlertDialogBody>
<AlertDialogFooter>
<IAIButton ref={cancelRef} onClick={onClose}>
Cancel
</IAIButton>
<IAIButton colorScheme="error" onClick={onImmediatelyDelete} ml={3}>
Delete
</IAIButton>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialogOverlay>
</AlertDialog>
);
};
export default memo(DeleteImageModal);
const deleteImageButtonsSelector = createSelector(
[systemSelector],
(system) => {
const { isProcessing, isConnected } = system;
return isConnected && !isProcessing;
}
);
type DeleteImageButtonProps = {
onClick: () => void;
};
export const DeleteImageButton = (props: DeleteImageButtonProps) => {
const { onClick } = props;
const { t } = useTranslation();
const canDeleteImage = useAppSelector(deleteImageButtonsSelector);
return (
<IAIIconButton
onClick={onClick}
icon={<FaTrash />}
tooltip={`${t('gallery.deleteImage')} (Del)`}
aria-label={`${t('gallery.deleteImage')} (Del)`}
isDisabled={!canDeleteImage}
colorScheme="error"
/>
);
};

View File

@ -0,0 +1,132 @@
import { Box } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage';
import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice';
import { MouseEvent, memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { FaTrash } from 'react-icons/fa';
import { ImageDTO } from 'services/api/types';
import {
imageRangeEndSelected,
imageSelected,
imageSelectionToggled,
} from '../store/gallerySlice';
import ImageContextMenu from './ImageContextMenu';
export const makeSelector = (image_name: string) =>
createSelector(
[stateSelector],
({ gallery }) => {
const isSelected = gallery.selection.includes(image_name);
const selectionCount = gallery.selection.length;
return {
isSelected,
selectionCount,
};
},
defaultSelectorOptions
);
interface HoverableImageProps {
imageDTO: ImageDTO;
}
/**
* Gallery image component with delete/use all/use seed buttons on hover.
*/
const GalleryImage = (props: HoverableImageProps) => {
const { imageDTO } = props;
const { image_url, thumbnail_url, image_name } = imageDTO;
const localSelector = useMemo(() => makeSelector(image_name), [image_name]);
const { isSelected, selectionCount } = useAppSelector(localSelector);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
if (e.shiftKey) {
dispatch(imageRangeEndSelected(props.imageDTO.image_name));
} else if (e.ctrlKey || e.metaKey) {
dispatch(imageSelectionToggled(props.imageDTO.image_name));
} else {
dispatch(imageSelected(props.imageDTO.image_name));
}
},
[dispatch, props.imageDTO.image_name]
);
const handleDelete = useCallback(
(e: MouseEvent<HTMLButtonElement>) => {
e.stopPropagation();
if (!imageDTO) {
return;
}
dispatch(imageToDeleteSelected(imageDTO));
},
[dispatch, imageDTO]
);
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
if (selectionCount > 1) {
return {
id: 'gallery-image',
payloadType: 'GALLERY_SELECTION',
};
}
if (imageDTO) {
return {
id: 'gallery-image',
payloadType: 'IMAGE_DTO',
payload: { imageDTO },
};
}
}, [imageDTO, selectionCount]);
return (
<Box sx={{ w: 'full', h: 'full', touchAction: 'none' }}>
<ImageContextMenu image={imageDTO}>
{(ref) => (
<Box
position="relative"
key={image_name}
userSelect="none"
ref={ref}
sx={{
display: 'flex',
justifyContent: 'center',
alignItems: 'center',
aspectRatio: '1/1',
}}
>
<IAIDndImage
onClick={handleClick}
imageDTO={imageDTO}
draggableData={draggableData}
isSelected={isSelected}
minSize={0}
onClickReset={handleDelete}
resetIcon={<FaTrash />}
resetTooltip="Delete image"
imageSx={{ w: 'full', h: 'full' }}
// withResetIcon // removed bc it's too easy to accidentally delete images
isDropDisabled={true}
isUploadDisabled={true}
/>
</Box>
)}
</ImageContextMenu>
</Box>
);
};
export default memo(GalleryImage);

View File

@ -1,371 +0,0 @@
import { Box, Flex, Icon, Image, MenuItem, MenuList } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useContext, useState } from 'react';
import {
FaCheck,
FaExpand,
FaFolder,
FaImage,
FaShare,
FaTrash,
} from 'react-icons/fa';
import { ContextMenu } from 'chakra-ui-contextmenu';
import {
resizeAndScaleCanvas,
setInitialCanvasImage,
} from 'features/canvas/store/canvasSlice';
import { gallerySelector } from 'features/gallery/store/gallerySelectors';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { useTranslation } from 'react-i18next';
import IAIIconButton from 'common/components/IAIIconButton';
import { ExternalLinkIcon } from '@chakra-ui/icons';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import { createSelector } from '@reduxjs/toolkit';
import { systemSelector } from 'features/system/store/systemSelectors';
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash-es';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import { sentImageToCanvas, sentImageToImg2Img } from '../store/actions';
import { useAppToaster } from 'app/components/Toaster';
import { ImageDTO } from 'services/api/types';
import { useDraggable } from '@dnd-kit/core';
import { DeleteImageContext } from 'app/contexts/DeleteImageContext';
import { AddImageToBoardContext } from '../../../app/contexts/AddImageToBoardContext';
import { useRemoveImageFromBoardMutation } from 'services/api/endpoints/boardImages';
export const selector = createSelector(
[gallerySelector, systemSelector, lightboxSelector, activeTabNameSelector],
(gallery, system, lightbox, activeTabName) => {
const {
galleryImageObjectFit,
galleryImageMinimumWidth,
shouldUseSingleGalleryColumn,
} = gallery;
const { isLightboxOpen } = lightbox;
const { isConnected, isProcessing, shouldConfirmOnDelete } = system;
return {
canDeleteImage: isConnected && !isProcessing,
shouldConfirmOnDelete,
galleryImageObjectFit,
galleryImageMinimumWidth,
shouldUseSingleGalleryColumn,
activeTabName,
isLightboxOpen,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
interface HoverableImageProps {
image: ImageDTO;
isSelected: boolean;
}
/**
* Gallery image component with delete/use all/use seed buttons on hover.
*/
const HoverableImage = (props: HoverableImageProps) => {
const dispatch = useAppDispatch();
const {
activeTabName,
galleryImageObjectFit,
galleryImageMinimumWidth,
canDeleteImage,
shouldUseSingleGalleryColumn,
} = useAppSelector(selector);
const { image, isSelected } = props;
const { image_url, thumbnail_url, image_name } = image;
const [isHovered, setIsHovered] = useState<boolean>(false);
const toaster = useAppToaster();
const { t } = useTranslation();
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const { onDelete } = useContext(DeleteImageContext);
const { onClickAddToBoard } = useContext(AddImageToBoardContext);
const handleDelete = useCallback(() => {
onDelete(image);
}, [image, onDelete]);
const { recallBothPrompts, recallSeed, recallAllParameters } =
useRecallParameters();
const { attributes, listeners, setNodeRef } = useDraggable({
id: `galleryImage_${image_name}`,
data: {
image,
},
});
const [removeFromBoard] = useRemoveImageFromBoardMutation();
const handleMouseOver = () => setIsHovered(true);
const handleMouseOut = () => setIsHovered(false);
const handleSelectImage = useCallback(() => {
dispatch(imageSelected(image.image_name));
}, [image, dispatch]);
// Recall parameters handlers
const handleRecallPrompt = useCallback(() => {
recallBothPrompts(
image.metadata?.positive_conditioning,
image.metadata?.negative_conditioning
);
}, [
image.metadata?.negative_conditioning,
image.metadata?.positive_conditioning,
recallBothPrompts,
]);
const handleRecallSeed = useCallback(() => {
recallSeed(image.metadata?.seed);
}, [image, recallSeed]);
const handleSendToImageToImage = useCallback(() => {
dispatch(sentImageToImg2Img());
dispatch(initialImageSelected(image));
}, [dispatch, image]);
// const handleRecallInitialImage = useCallback(() => {
// recallInitialImage(image.metadata.invokeai?.node?.image);
// }, [image, recallInitialImage]);
/**
* TODO: the rest of these
*/
const handleSendToCanvas = () => {
dispatch(sentImageToCanvas());
dispatch(setInitialCanvasImage(image));
dispatch(resizeAndScaleCanvas());
if (activeTabName !== 'unifiedCanvas') {
dispatch(setActiveTab('unifiedCanvas'));
}
toaster({
title: t('toast.sentToUnifiedCanvas'),
status: 'success',
duration: 2500,
isClosable: true,
});
};
const handleUseAllParameters = useCallback(() => {
recallAllParameters(image);
}, [image, recallAllParameters]);
const handleLightBox = () => {
// dispatch(setCurrentImage(image));
// dispatch(setIsLightboxOpen(true));
};
const handleAddToBoard = useCallback(() => {
onClickAddToBoard(image);
}, [image, onClickAddToBoard]);
const handleRemoveFromBoard = useCallback(() => {
if (!image.board_id) {
return;
}
removeFromBoard({ board_id: image.board_id, image_name: image.image_name });
}, [image.board_id, image.image_name, removeFromBoard]);
const handleOpenInNewTab = () => {
window.open(image.image_url, '_blank');
};
return (
<Box
ref={setNodeRef}
{...listeners}
{...attributes}
sx={{ w: 'full', h: 'full', touchAction: 'none' }}
>
<ContextMenu<HTMLDivElement>
menuProps={{ size: 'sm', isLazy: true }}
renderMenu={() => (
<MenuList sx={{ visibility: 'visible !important' }}>
<MenuItem
icon={<ExternalLinkIcon />}
onClickCapture={handleOpenInNewTab}
>
{t('common.openInNewTab')}
</MenuItem>
{isLightboxEnabled && (
<MenuItem icon={<FaExpand />} onClickCapture={handleLightBox}>
{t('parameters.openInViewer')}
</MenuItem>
)}
<MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleRecallPrompt}
isDisabled={image?.metadata?.positive_conditioning === undefined}
>
{t('parameters.usePrompt')}
</MenuItem>
<MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleRecallSeed}
isDisabled={image?.metadata?.seed === undefined}
>
{t('parameters.useSeed')}
</MenuItem>
{/* <MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleRecallInitialImage}
isDisabled={image?.metadata?.type !== 'img2img'}
>
{t('parameters.useInitImg')}
</MenuItem> */}
<MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleUseAllParameters}
isDisabled={
// what should these be
!['t2l', 'l2l', 'inpaint'].includes(
String(image?.metadata?.type)
)
}
>
{t('parameters.useAll')}
</MenuItem>
<MenuItem
icon={<FaShare />}
onClickCapture={handleSendToImageToImage}
id="send-to-img2img"
>
{t('parameters.sendToImg2Img')}
</MenuItem>
{isCanvasEnabled && (
<MenuItem
icon={<FaShare />}
onClickCapture={handleSendToCanvas}
id="send-to-canvas"
>
{t('parameters.sendToUnifiedCanvas')}
</MenuItem>
)}
<MenuItem icon={<FaFolder />} onClickCapture={handleAddToBoard}>
{image.board_id ? 'Change Board' : 'Add to Board'}
</MenuItem>
{image.board_id && (
<MenuItem
icon={<FaFolder />}
onClickCapture={handleRemoveFromBoard}
>
Remove from Board
</MenuItem>
)}
<MenuItem
sx={{ color: 'error.300' }}
icon={<FaTrash />}
onClickCapture={handleDelete}
>
{t('gallery.deleteImage')}
</MenuItem>
</MenuList>
)}
>
{(ref) => (
<Box
position="relative"
key={image_name}
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
userSelect="none"
onClick={handleSelectImage}
ref={ref}
sx={{
display: 'flex',
justifyContent: 'center',
alignItems: 'center',
w: 'full',
h: 'full',
transition: 'transform 0.2s ease-out',
aspectRatio: '1/1',
cursor: 'pointer',
}}
>
<Image
loading="lazy"
objectFit={
shouldUseSingleGalleryColumn ? 'contain' : galleryImageObjectFit
}
draggable={false}
rounded="md"
src={thumbnail_url || image_url}
fallback={<FaImage />}
sx={{
width: '100%',
height: '100%',
maxWidth: '100%',
maxHeight: '100%',
}}
/>
{isSelected && (
<Flex
sx={{
position: 'absolute',
top: '0',
insetInlineStart: '0',
width: '100%',
height: '100%',
alignItems: 'center',
justifyContent: 'center',
pointerEvents: 'none',
}}
>
<Icon
filter={'drop-shadow(0px 0px 1rem black)'}
as={FaCheck}
sx={{
width: '50%',
height: '50%',
maxWidth: '4rem',
maxHeight: '4rem',
fill: 'ok.500',
}}
/>
</Flex>
)}
{isHovered && galleryImageMinimumWidth >= 100 && (
<Box
sx={{
position: 'absolute',
top: 1,
insetInlineEnd: 1,
}}
>
<IAIIconButton
onClickCapture={handleDelete}
aria-label={t('gallery.deleteImage')}
icon={<FaTrash />}
size="xs"
fontSize={14}
isDisabled={!canDeleteImage}
/>
</Box>
)}
</Box>
)}
</ContextMenu>
</Box>
);
};
export default memo(HoverableImage);

View File

@ -0,0 +1,278 @@
import { MenuItem, MenuList } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { memo, useCallback, useContext } from 'react';
import {
FaExpand,
FaFolder,
FaFolderPlus,
FaShare,
FaTrash,
} from 'react-icons/fa';
import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu';
import {
resizeAndScaleCanvas,
setInitialCanvasImage,
} from 'features/canvas/store/canvasSlice';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { useTranslation } from 'react-i18next';
import { ExternalLinkIcon } from '@chakra-ui/icons';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import { createSelector } from '@reduxjs/toolkit';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import { sentImageToCanvas, sentImageToImg2Img } from '../store/actions';
import { useAppToaster } from 'app/components/Toaster';
import { AddImageToBoardContext } from '../../../app/contexts/AddImageToBoardContext';
import { useRemoveImageFromBoardMutation } from 'services/api/endpoints/boardImages';
import { ImageDTO } from 'services/api/types';
import { RootState, stateSelector } from 'app/store/store';
import {
imagesAddedToBatch,
selectionAddedToBatch,
} from 'features/batch/store/batchSlice';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice';
const selector = createSelector(
[stateSelector, (state: RootState, imageDTO: ImageDTO) => imageDTO],
({ gallery, batch }, imageDTO) => {
const selectionCount = gallery.selection.length;
const isInBatch = batch.imageNames.includes(imageDTO.image_name);
return { selectionCount, isInBatch };
},
defaultSelectorOptions
);
type Props = {
image: ImageDTO;
children: ContextMenuProps<HTMLDivElement>['children'];
};
const ImageContextMenu = ({ image, children }: Props) => {
const { selectionCount, isInBatch } = useAppSelector((state) =>
selector(state, image)
);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const toaster = useAppToaster();
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const { onClickAddToBoard } = useContext(AddImageToBoardContext);
const handleDelete = useCallback(() => {
if (!image) {
return;
}
dispatch(imageToDeleteSelected(image));
}, [dispatch, image]);
const { recallBothPrompts, recallSeed, recallAllParameters } =
useRecallParameters();
const [removeFromBoard] = useRemoveImageFromBoardMutation();
// Recall parameters handlers
const handleRecallPrompt = useCallback(() => {
recallBothPrompts(
image.metadata?.positive_conditioning,
image.metadata?.negative_conditioning
);
}, [
image.metadata?.negative_conditioning,
image.metadata?.positive_conditioning,
recallBothPrompts,
]);
const handleRecallSeed = useCallback(() => {
recallSeed(image.metadata?.seed);
}, [image, recallSeed]);
const handleSendToImageToImage = useCallback(() => {
dispatch(sentImageToImg2Img());
dispatch(initialImageSelected(image));
}, [dispatch, image]);
// const handleRecallInitialImage = useCallback(() => {
// recallInitialImage(image.metadata.invokeai?.node?.image);
// }, [image, recallInitialImage]);
const handleSendToCanvas = () => {
dispatch(sentImageToCanvas());
dispatch(setInitialCanvasImage(image));
dispatch(resizeAndScaleCanvas());
dispatch(setActiveTab('unifiedCanvas'));
toaster({
title: t('toast.sentToUnifiedCanvas'),
status: 'success',
duration: 2500,
isClosable: true,
});
};
const handleUseAllParameters = useCallback(() => {
recallAllParameters(image);
}, [image, recallAllParameters]);
const handleLightBox = () => {
// dispatch(setCurrentImage(image));
// dispatch(setIsLightboxOpen(true));
};
const handleAddToBoard = useCallback(() => {
onClickAddToBoard(image);
}, [image, onClickAddToBoard]);
const handleRemoveFromBoard = useCallback(() => {
if (!image.board_id) {
return;
}
removeFromBoard({ board_id: image.board_id, image_name: image.image_name });
}, [image.board_id, image.image_name, removeFromBoard]);
const handleOpenInNewTab = () => {
window.open(image.image_url, '_blank');
};
const handleAddSelectionToBatch = useCallback(() => {
dispatch(selectionAddedToBatch());
}, [dispatch]);
const handleAddToBatch = useCallback(() => {
dispatch(imagesAddedToBatch([image.image_name]));
}, [dispatch, image.image_name]);
return (
<ContextMenu<HTMLDivElement>
menuProps={{ size: 'sm', isLazy: true }}
renderMenu={() => (
<MenuList sx={{ visibility: 'visible !important' }}>
{selectionCount === 1 ? (
<>
<MenuItem
icon={<ExternalLinkIcon />}
onClickCapture={handleOpenInNewTab}
>
{t('common.openInNewTab')}
</MenuItem>
{isLightboxEnabled && (
<MenuItem icon={<FaExpand />} onClickCapture={handleLightBox}>
{t('parameters.openInViewer')}
</MenuItem>
)}
<MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleRecallPrompt}
isDisabled={
image?.metadata?.positive_conditioning === undefined
}
>
{t('parameters.usePrompt')}
</MenuItem>
<MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleRecallSeed}
isDisabled={image?.metadata?.seed === undefined}
>
{t('parameters.useSeed')}
</MenuItem>
{/* <MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleRecallInitialImage}
isDisabled={image?.metadata?.type !== 'img2img'}
>
{t('parameters.useInitImg')}
</MenuItem> */}
<MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleUseAllParameters}
isDisabled={
// what should these be
!['t2l', 'l2l', 'inpaint'].includes(
String(image?.metadata?.type)
)
}
>
{t('parameters.useAll')}
</MenuItem>
<MenuItem
icon={<FaShare />}
onClickCapture={handleSendToImageToImage}
id="send-to-img2img"
>
{t('parameters.sendToImg2Img')}
</MenuItem>
{isCanvasEnabled && (
<MenuItem
icon={<FaShare />}
onClickCapture={handleSendToCanvas}
id="send-to-canvas"
>
{t('parameters.sendToUnifiedCanvas')}
</MenuItem>
)}
{/* <MenuItem
icon={<FaFolder />}
isDisabled={isInBatch}
onClickCapture={handleAddToBatch}
>
Add to Batch
</MenuItem> */}
<MenuItem icon={<FaFolder />} onClickCapture={handleAddToBoard}>
{image.board_id ? 'Change Board' : 'Add to Board'}
</MenuItem>
{image.board_id && (
<MenuItem
icon={<FaFolder />}
onClickCapture={handleRemoveFromBoard}
>
Remove from Board
</MenuItem>
)}
<MenuItem
sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
icon={<FaTrash />}
onClickCapture={handleDelete}
>
{t('gallery.deleteImage')}
</MenuItem>
</>
) : (
<>
<MenuItem
isDisabled={true}
icon={<FaFolder />}
onClickCapture={handleAddToBoard}
>
Move Selection to Board
</MenuItem>
{/* <MenuItem
icon={<FaFolderPlus />}
onClickCapture={handleAddSelectionToBatch}
>
Add Selection to Batch
</MenuItem> */}
<MenuItem
sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
icon={<FaTrash />}
onClickCapture={handleDelete}
>
Delete Selection
</MenuItem>
</>
)}
</MenuList>
)}
>
{children}
</ContextMenu>
);
};
export default memo(ImageContextMenu);

View File

@ -5,7 +5,7 @@ import {
Flex, Flex,
FlexProps, FlexProps,
Grid, Grid,
Icon, Skeleton,
Text, Text,
VStack, VStack,
forwardRef, forwardRef,
@ -18,12 +18,8 @@ import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import IAIPopover from 'common/components/IAIPopover'; import IAIPopover from 'common/components/IAIPopover';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { gallerySelector } from 'features/gallery/store/gallerySelectors';
import { import {
setGalleryImageMinimumWidth, setGalleryImageMinimumWidth,
setGalleryImageObjectFit,
setShouldAutoSwitchToNewImages,
setShouldUseSingleGalleryColumn,
setGalleryView, setGalleryView,
} from 'features/gallery/store/gallerySlice'; } from 'features/gallery/store/gallerySlice';
import { togglePinGalleryPanel } from 'features/ui/store/uiSlice'; import { togglePinGalleryPanel } from 'features/ui/store/uiSlice';
@ -42,77 +38,56 @@ import {
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { BsPinAngle, BsPinAngleFill } from 'react-icons/bs'; import { BsPinAngle, BsPinAngleFill } from 'react-icons/bs';
import { FaImage, FaServer, FaWrench } from 'react-icons/fa'; import { FaImage, FaServer, FaWrench } from 'react-icons/fa';
import { MdPhotoLibrary } from 'react-icons/md'; import GalleryImage from './GalleryImage';
import HoverableImage from './HoverableImage';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store'; import { RootState, stateSelector } from 'app/store/store';
import { Virtuoso, VirtuosoGrid } from 'react-virtuoso'; import { VirtuosoGrid } from 'react-virtuoso';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { import {
ASSETS_CATEGORIES, ASSETS_CATEGORIES,
IMAGE_CATEGORIES, IMAGE_CATEGORIES,
imageCategoriesChanged, imageCategoriesChanged,
selectImagesAll, shouldAutoSwitchChanged,
} from '../store/imagesSlice'; selectFilteredImages,
} from 'features/gallery/store/gallerySlice';
import { receivedPageOfImages } from 'services/api/thunks/image'; import { receivedPageOfImages } from 'services/api/thunks/image';
import BoardsList from './Boards/BoardsList'; import BoardsList from './Boards/BoardsList';
import { boardsSelector } from '../store/boardSlice';
import { ChevronUpIcon } from '@chakra-ui/icons'; import { ChevronUpIcon } from '@chakra-ui/icons';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards'; import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
import { mode } from 'theme/util/mode'; import { mode } from 'theme/util/mode';
import { ImageDTO } from 'services/api/types';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
const itemSelector = createSelector( const LOADING_IMAGE_ARRAY = Array(20).fill('loading');
[(state: RootState) => state],
(state) => {
const { categories, total: allImagesTotal, isLoading } = state.images;
const { selectedBoardId } = state.boards;
const allImages = selectImagesAll(state); const selector = createSelector(
[stateSelector, selectFilteredImages],
(state, filteredImages) => {
const {
categories,
total: allImagesTotal,
isLoading,
selectedBoardId,
galleryImageMinimumWidth,
galleryView,
shouldAutoSwitch,
} = state.gallery;
const { shouldPinGallery } = state.ui;
const images = allImages.filter((i) => { const images = filteredImages as (ImageDTO | string)[];
const isInCategory = categories.includes(i.image_category);
const isInSelectedBoard = selectedBoardId
? i.board_id === selectedBoardId
: true;
return isInCategory && isInSelectedBoard;
});
return { return {
images, images: isLoading ? images.concat(LOADING_IMAGE_ARRAY) : images,
allImagesTotal, allImagesTotal,
isLoading, isLoading,
categories, categories,
selectedBoardId, selectedBoardId,
};
},
defaultSelectorOptions
);
const mainSelector = createSelector(
[gallerySelector, uiSelector, boardsSelector],
(gallery, ui, boards) => {
const {
galleryImageMinimumWidth,
galleryImageObjectFit,
shouldAutoSwitchToNewImages,
shouldUseSingleGalleryColumn,
selectedImage,
galleryView,
} = gallery;
const { shouldPinGallery } = ui;
return {
shouldPinGallery, shouldPinGallery,
galleryImageMinimumWidth, galleryImageMinimumWidth,
galleryImageObjectFit, shouldAutoSwitch,
shouldAutoSwitchToNewImages,
shouldUseSingleGalleryColumn,
selectedImage,
galleryView, galleryView,
selectedBoardId: boards.selectedBoardId,
}; };
}, },
defaultSelectorOptions defaultSelectorOptions
@ -140,17 +115,16 @@ const ImageGalleryContent = () => {
const { colorMode } = useColorMode(); const { colorMode } = useColorMode();
const { const {
images,
isLoading,
allImagesTotal,
categories,
selectedBoardId,
shouldPinGallery, shouldPinGallery,
galleryImageMinimumWidth, galleryImageMinimumWidth,
galleryImageObjectFit, shouldAutoSwitch,
shouldAutoSwitchToNewImages,
shouldUseSingleGalleryColumn,
selectedImage,
galleryView, galleryView,
} = useAppSelector(mainSelector); } = useAppSelector(selector);
const { images, isLoading, allImagesTotal, categories, selectedBoardId } =
useAppSelector(itemSelector);
const { selectedBoard } = useListAllBoardsQuery(undefined, { const { selectedBoard } = useListAllBoardsQuery(undefined, {
selectFromResult: ({ data }) => ({ selectFromResult: ({ data }) => ({
@ -208,11 +182,14 @@ const ImageGalleryContent = () => {
return () => osInstance()?.destroy(); return () => osInstance()?.destroy();
}, [scroller, initialize, osInstance]); }, [scroller, initialize, osInstance]);
const setScrollerRef = useCallback((ref: HTMLElement | Window | null) => { useEffect(() => {
if (ref instanceof HTMLElement) { dispatch(
setScroller(ref); receivedPageOfImages({
} categories: ['general'],
}, []); is_intermediate: false,
})
);
}, [dispatch]);
const handleClickImagesCategory = useCallback(() => { const handleClickImagesCategory = useCallback(() => {
dispatch(imageCategoriesChanged(IMAGE_CATEGORIES)); dispatch(imageCategoriesChanged(IMAGE_CATEGORIES));
@ -314,29 +291,11 @@ const ImageGalleryContent = () => {
withReset withReset
handleReset={() => dispatch(setGalleryImageMinimumWidth(64))} handleReset={() => dispatch(setGalleryImageMinimumWidth(64))}
/> />
<IAISimpleCheckbox
label={t('gallery.maintainAspectRatio')}
isChecked={galleryImageObjectFit === 'contain'}
onChange={() =>
dispatch(
setGalleryImageObjectFit(
galleryImageObjectFit === 'contain' ? 'cover' : 'contain'
)
)
}
/>
<IAISimpleCheckbox <IAISimpleCheckbox
label={t('gallery.autoSwitchNewImages')} label={t('gallery.autoSwitchNewImages')}
isChecked={shouldAutoSwitchToNewImages} isChecked={shouldAutoSwitch}
onChange={(e: ChangeEvent<HTMLInputElement>) => onChange={(e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldAutoSwitchToNewImages(e.target.checked)) dispatch(shouldAutoSwitchChanged(e.target.checked))
}
/>
<IAISimpleCheckbox
label={t('gallery.singleColumnLayout')}
isChecked={shouldUseSingleGalleryColumn}
onChange={(e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldUseSingleGalleryColumn(e.target.checked))
} }
/> />
</Flex> </Flex>
@ -358,23 +317,6 @@ const ImageGalleryContent = () => {
{images.length || areMoreAvailable ? ( {images.length || areMoreAvailable ? (
<> <>
<Box ref={rootRef} data-overlayscrollbars="" h="100%"> <Box ref={rootRef} data-overlayscrollbars="" h="100%">
{shouldUseSingleGalleryColumn ? (
<Virtuoso
style={{ height: '100%' }}
data={images}
endReached={handleEndReached}
scrollerRef={(ref) => setScrollerRef(ref)}
itemContent={(index, item) => (
<Flex sx={{ pb: 2 }}>
<HoverableImage
key={`${item.image_name}-${item.thumbnail_url}`}
image={item}
isSelected={selectedImage === item?.image_name}
/>
</Flex>
)}
/>
) : (
<VirtuosoGrid <VirtuosoGrid
style={{ height: '100%' }} style={{ height: '100%' }}
data={images} data={images}
@ -384,15 +326,19 @@ const ImageGalleryContent = () => {
List: ListContainer, List: ListContainer,
}} }}
scrollerRef={setScroller} scrollerRef={setScroller}
itemContent={(index, item) => ( itemContent={(index, item) =>
<HoverableImage typeof item === 'string' ? (
<Skeleton
sx={{ w: 'full', h: 'full', aspectRatio: '1/1' }}
/>
) : (
<GalleryImage
key={`${item.image_name}-${item.thumbnail_url}`} key={`${item.image_name}-${item.thumbnail_url}`}
image={item} imageDTO={item}
isSelected={selectedImage === item?.image_name}
/> />
)} )
}
/> />
)}
</Box> </Box>
<IAIButton <IAIButton
onClick={handleLoadMoreImages} onClick={handleLoadMoreImages}
@ -407,27 +353,10 @@ const ImageGalleryContent = () => {
</IAIButton> </IAIButton>
</> </>
) : ( ) : (
<Flex <IAINoContentFallback
sx={{ label={t('gallery.noImagesInGallery')}
flexDirection: 'column', icon={FaImage}
alignItems: 'center',
justifyContent: 'center',
gap: 2,
padding: 8,
h: '100%',
w: '100%',
color: 'base.500',
}}
>
<Icon
as={MdPhotoLibrary}
sx={{
w: 16,
h: 16,
}}
/> />
<Text textAlign="center">{t('gallery.noImagesInGallery')}</Text>
</Flex>
)} )}
</Flex> </Flex>
</VStack> </VStack>
@ -436,7 +365,7 @@ const ImageGalleryContent = () => {
type ItemContainerProps = PropsWithChildren & FlexProps; type ItemContainerProps = PropsWithChildren & FlexProps;
const ItemContainer = forwardRef((props: ItemContainerProps, ref) => ( const ItemContainer = forwardRef((props: ItemContainerProps, ref) => (
<Box className="item-container" ref={ref}> <Box className="item-container" ref={ref} p={1.5}>
{props.children} {props.children}
</Box> </Box>
)); ));
@ -453,8 +382,7 @@ const ListContainer = forwardRef((props: ListContainerProps, ref) => {
className="list-container" className="list-container"
ref={ref} ref={ref}
sx={{ sx={{
gap: 2, gridTemplateColumns: `repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, 1fr));`,
gridTemplateColumns: `repeat(auto-fit, minmax(${galleryImageMinimumWidth}px, 1fr));`,
}} }}
> >
{props.children} {props.children}

View File

@ -5,14 +5,13 @@ import { clamp, isEqual } from 'lodash-es';
import { useCallback, useState } from 'react'; import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { FaAngleLeft, FaAngleRight } from 'react-icons/fa'; import { FaAngleLeft, FaAngleRight } from 'react-icons/fa';
import { gallerySelector } from '../store/gallerySelectors'; import { stateSelector } from 'app/store/store';
import { RootState } from 'app/store/store';
import { imageSelected } from '../store/gallerySlice';
import { useHotkeys } from 'react-hotkeys-hook';
import { import {
selectFilteredImagesAsObject, imageSelected,
selectFilteredImagesIds, selectImagesById,
} from '../store/imagesSlice'; } from 'features/gallery/store/gallerySlice';
import { useHotkeys } from 'react-hotkeys-hook';
import { selectFilteredImages } from 'features/gallery/store/gallerySlice';
const nextPrevButtonTriggerAreaStyles: ChakraProps['sx'] = { const nextPrevButtonTriggerAreaStyles: ChakraProps['sx'] = {
height: '100%', height: '100%',
@ -25,45 +24,40 @@ const nextPrevButtonStyles: ChakraProps['sx'] = {
}; };
export const nextPrevImageButtonsSelector = createSelector( export const nextPrevImageButtonsSelector = createSelector(
[ [stateSelector, selectFilteredImages],
(state: RootState) => state, (state, filteredImages) => {
gallerySelector, const lastSelectedImage =
selectFilteredImagesAsObject, state.gallery.selection[state.gallery.selection.length - 1];
selectFilteredImagesIds,
],
(state, gallery, filteredImagesAsObject, filteredImageIds) => {
const { selectedImage } = gallery;
if (!selectedImage) { if (!lastSelectedImage || filteredImages.length === 0) {
return { return {
isOnFirstImage: true, isOnFirstImage: true,
isOnLastImage: true, isOnLastImage: true,
}; };
} }
const currentImageIndex = filteredImageIds.findIndex( const currentImageIndex = filteredImages.findIndex(
(i) => i === selectedImage (i) => i.image_name === lastSelectedImage
); );
const nextImageIndex = clamp( const nextImageIndex = clamp(
currentImageIndex + 1, currentImageIndex + 1,
0, 0,
filteredImageIds.length - 1 filteredImages.length - 1
); );
const prevImageIndex = clamp( const prevImageIndex = clamp(
currentImageIndex - 1, currentImageIndex - 1,
0, 0,
filteredImageIds.length - 1 filteredImages.length - 1
); );
const nextImageId = filteredImageIds[nextImageIndex]; const nextImageId = filteredImages[nextImageIndex].image_name;
const prevImageId = filteredImageIds[prevImageIndex]; const prevImageId = filteredImages[prevImageIndex].image_name;
const nextImage = filteredImagesAsObject[nextImageId]; const nextImage = selectImagesById(state, nextImageId);
const prevImage = filteredImagesAsObject[prevImageId]; const prevImage = selectImagesById(state, prevImageId);
const imagesLength = filteredImageIds.length; const imagesLength = filteredImages.length;
return { return {
isOnFirstImage: currentImageIndex === 0, isOnFirstImage: currentImageIndex === 0,
@ -101,11 +95,11 @@ const NextPrevImageButtons = () => {
}, []); }, []);
const handlePrevImage = useCallback(() => { const handlePrevImage = useCallback(() => {
dispatch(imageSelected(prevImageId)); prevImageId && dispatch(imageSelected(prevImageId));
}, [dispatch, prevImageId]); }, [dispatch, prevImageId]);
const handleNextImage = useCallback(() => { const handleNextImage = useCallback(() => {
dispatch(imageSelected(nextImageId)); nextImageId && dispatch(imageSelected(nextImageId));
}, [dispatch, nextImageId]); }, [dispatch, nextImageId]);
useHotkeys( useHotkeys(

View File

@ -1,40 +0,0 @@
import { useColorMode, useToken } from '@chakra-ui/react';
import { motion } from 'framer-motion';
import { mode } from 'theme/util/mode';
export const SelectedItemOverlay = () => {
const [accent400, accent500] = useToken('colors', [
'accent.400',
'accent.500',
]);
const { colorMode } = useColorMode();
return (
<motion.div
initial={{
opacity: 0,
}}
animate={{
opacity: 1,
transition: { duration: 0.1 },
}}
exit={{
opacity: 0,
transition: { duration: 0.1 },
}}
style={{
position: 'absolute',
top: 0,
insetInlineStart: 0,
width: '100%',
height: '100%',
boxShadow: `inset 0px 0px 0px 2px ${mode(
accent400,
accent500
)(colorMode)}`,
borderRadius: 'var(--invokeai-radii-base)',
}}
/>
);
};

View File

@ -1,18 +0,0 @@
import { useAppSelector } from 'app/store/storeHooks';
import { selectImagesEntities } from '../store/imagesSlice';
import { useCallback } from 'react';
const useGetImageByName = () => {
const images = useAppSelector(selectImagesEntities);
return useCallback(
(name: string | undefined) => {
if (!name) {
return;
}
return images[name];
},
[images]
);
};
export default useGetImageByName;

View File

@ -1,15 +1,6 @@
import { createAction } from '@reduxjs/toolkit'; import { createAction } from '@reduxjs/toolkit';
import { ImageUsage } from 'app/contexts/DeleteImageContext'; import { ImageUsage } from 'app/contexts/AddImageToBoardContext';
import { ImageDTO, BoardDTO } from 'services/api/types'; import { BoardDTO } from 'services/api/types';
export type RequestedImageDeletionArg = {
image: ImageDTO;
imageUsage: ImageUsage;
};
export const requestedImageDeletion = createAction<RequestedImageDeletionArg>(
'gallery/requestedImageDeletion'
);
export type RequestedBoardImagesDeletionArg = { export type RequestedBoardImagesDeletionArg = {
board: BoardDTO; board: BoardDTO;

View File

@ -1,10 +1,8 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit'; import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { boardsApi } from 'services/api/endpoints/boards';
type BoardsState = { type BoardsState = {
searchText: string; searchText: string;
selectedBoardId?: string;
updateBoardModalOpen: boolean; updateBoardModalOpen: boolean;
}; };
@ -17,9 +15,6 @@ const boardsSlice = createSlice({
name: 'boards', name: 'boards',
initialState: initialBoardsState, initialState: initialBoardsState,
reducers: { reducers: {
boardIdSelected: (state, action: PayloadAction<string | undefined>) => {
state.selectedBoardId = action.payload;
},
setBoardSearchText: (state, action: PayloadAction<string>) => { setBoardSearchText: (state, action: PayloadAction<string>) => {
state.searchText = action.payload; state.searchText = action.payload;
}, },
@ -27,19 +22,9 @@ const boardsSlice = createSlice({
state.updateBoardModalOpen = action.payload; state.updateBoardModalOpen = action.payload;
}, },
}, },
extraReducers: (builder) => {
builder.addMatcher(
boardsApi.endpoints.deleteBoard.matchFulfilled,
(state, action) => {
if (action.meta.arg.originalArgs === state.selectedBoardId) {
state.selectedBoardId = undefined;
}
}
);
},
}); });
export const { boardIdSelected, setBoardSearchText, setUpdateBoardModalOpen } = export const { setBoardSearchText, setUpdateBoardModalOpen } =
boardsSlice.actions; boardsSlice.actions;
export const boardsSelector = (state: RootState) => state.boards; export const boardsSelector = (state: RootState) => state.boards;

View File

@ -1,8 +1,15 @@
import { GalleryState } from './gallerySlice'; import { initialGalleryState } from './gallerySlice';
/** /**
* Gallery slice persist denylist * Gallery slice persist denylist
*/ */
export const galleryPersistDenylist: (keyof GalleryState)[] = [ export const galleryPersistDenylist: (keyof typeof initialGalleryState)[] = [
'shouldAutoSwitchToNewImages', 'selection',
'entities',
'ids',
'isLoading',
'limit',
'offset',
'selectedBoardId',
'total',
]; ];

Some files were not shown because too many files have changed in this diff Show More