mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into doc_updates_23
This commit is contained in:
commit
b250d1ec86
2
.gitignore
vendored
2
.gitignore
vendored
@ -201,8 +201,6 @@ checkpoints
|
||||
# If it's a Mac
|
||||
.DS_Store
|
||||
|
||||
invokeai/frontend/web/dist/*
|
||||
|
||||
# Let the frontend manage its own gitignore
|
||||
!invokeai/frontend/web/*
|
||||
|
||||
|
@ -2,17 +2,17 @@
|
||||
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from fastapi import Query
|
||||
from fastapi import Query, Body
|
||||
from fastapi.routing import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field, parse_obj_as
|
||||
from ..dependencies import ApiDependencies
|
||||
from invokeai.backend import BaseModelType, ModelType
|
||||
from invokeai.backend.model_management import AddModelResult
|
||||
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
|
||||
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
|
||||
class VaeRepo(BaseModel):
|
||||
repo_id: str = Field(description="The repo ID to use for this VAE")
|
||||
path: Optional[str] = Field(description="The path to the VAE")
|
||||
@ -51,9 +51,12 @@ class CreateModelResponse(BaseModel):
|
||||
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
|
||||
status: str = Field(description="The status of the API response")
|
||||
|
||||
class ImportModelRequest(BaseModel):
|
||||
name: str = Field(description="A model path, repo_id or URL to import")
|
||||
prediction_type: Optional[Literal['epsilon','v_prediction','sample']] = Field(description='Prediction type for SDv2 checkpoint files')
|
||||
class ImportModelResponse(BaseModel):
|
||||
name: str = Field(description="The name of the imported model")
|
||||
# base_model: str = Field(description="The base model")
|
||||
# model_type: str = Field(description="The model type")
|
||||
info: AddModelResult = Field(description="The model info")
|
||||
status: str = Field(description="The status of the API response")
|
||||
|
||||
class ConversionRequest(BaseModel):
|
||||
name: str = Field(description="The name of the new model")
|
||||
@ -86,7 +89,6 @@ async def list_models(
|
||||
models = parse_obj_as(ModelsList, { "models": models_raw })
|
||||
return models
|
||||
|
||||
|
||||
@models_router.post(
|
||||
"/",
|
||||
operation_id="update_model",
|
||||
@ -109,27 +111,38 @@ async def update_model(
|
||||
return model_response
|
||||
|
||||
@models_router.post(
|
||||
"/",
|
||||
"/import",
|
||||
operation_id="import_model",
|
||||
responses={200: {"status": "success"}},
|
||||
responses= {
|
||||
201: {"description" : "The model imported successfully"},
|
||||
404: {"description" : "The model could not be found"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImportModelResponse
|
||||
)
|
||||
async def import_model(
|
||||
model_request: ImportModelRequest
|
||||
) -> None:
|
||||
""" Add Model """
|
||||
items_to_import = set([model_request.name])
|
||||
name: str = Query(description="A model path, repo_id or URL to import"),
|
||||
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = Query(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
||||
) -> ImportModelResponse:
|
||||
""" Add a model using its local path, repo_id, or remote URL """
|
||||
items_to_import = {name}
|
||||
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||
items_to_import = items_to_import,
|
||||
prediction_type_helper = lambda x: prediction_types.get(model_request.prediction_type)
|
||||
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
|
||||
)
|
||||
if len(installed_models) > 0:
|
||||
logger.info(f'Successfully imported {model_request.name}')
|
||||
if info := installed_models.get(name):
|
||||
logger.info(f'Successfully imported {name}, got {info}')
|
||||
return ImportModelResponse(
|
||||
name = name,
|
||||
info = info,
|
||||
status = "success",
|
||||
)
|
||||
else:
|
||||
logger.error(f'Model {model_request.name} not imported')
|
||||
raise HTTPException(status_code=500, detail=f'Model {model_request.name} not imported')
|
||||
logger.error(f'Model {name} not imported')
|
||||
raise HTTPException(status_code=404, detail=f'Model {name} not found')
|
||||
|
||||
@models_router.delete(
|
||||
"/{model_name}",
|
||||
|
@ -4,9 +4,10 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict, TYPE_CHECKING
|
||||
from typing import (TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args,
|
||||
get_type_hints)
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseConfig, BaseModel, Field
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..services.invocation_services import InvocationServices
|
||||
@ -65,7 +66,12 @@ class BaseInvocation(ABC, BaseModel):
|
||||
@classmethod
|
||||
def get_invocations_map(cls):
|
||||
# Get the type strings out of the literals and into a dictionary
|
||||
return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseInvocation.get_all_subclasses()))
|
||||
return dict(
|
||||
map(
|
||||
lambda t: (get_args(get_type_hints(t)["type"])[0], t),
|
||||
BaseInvocation.get_all_subclasses(),
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_output_type(cls):
|
||||
@ -76,10 +82,10 @@ class BaseInvocation(ABC, BaseModel):
|
||||
"""Invoke with provided context and return outputs."""
|
||||
pass
|
||||
|
||||
#fmt: off
|
||||
# fmt: off
|
||||
id: str = Field(description="The id of this node. Must be unique among all nodes.")
|
||||
is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
|
||||
# TODO: figure out a better way to provide these hints
|
||||
@ -97,16 +103,20 @@ class UIConfig(TypedDict, total=False):
|
||||
"latents",
|
||||
"model",
|
||||
"control",
|
||||
"image_collection",
|
||||
"vae_model",
|
||||
"lora_model",
|
||||
],
|
||||
]
|
||||
tags: List[str]
|
||||
title: str
|
||||
|
||||
|
||||
class CustomisedSchemaExtra(TypedDict):
|
||||
ui: UIConfig
|
||||
|
||||
|
||||
class InvocationConfig(BaseModel.Config):
|
||||
class InvocationConfig(BaseConfig):
|
||||
"""Customizes pydantic's BaseModel.Config class for use by Invocations.
|
||||
|
||||
Provide `schema_extra` a `ui` dict to add hints for generated UIs.
|
||||
|
@ -4,13 +4,16 @@ from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field, validator
|
||||
from invokeai.app.models.image import ImageField
|
||||
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
InvocationConfig,
|
||||
InvocationContext,
|
||||
BaseInvocationOutput,
|
||||
UIConfig,
|
||||
)
|
||||
|
||||
|
||||
@ -22,6 +25,7 @@ class IntCollectionOutput(BaseInvocationOutput):
|
||||
# Outputs
|
||||
collection: list[int] = Field(default=[], description="The int collection")
|
||||
|
||||
|
||||
class FloatCollectionOutput(BaseInvocationOutput):
|
||||
"""A collection of floats"""
|
||||
|
||||
@ -31,6 +35,18 @@ class FloatCollectionOutput(BaseInvocationOutput):
|
||||
collection: list[float] = Field(default=[], description="The float collection")
|
||||
|
||||
|
||||
class ImageCollectionOutput(BaseInvocationOutput):
|
||||
"""A collection of images"""
|
||||
|
||||
type: Literal["image_collection"] = "image_collection"
|
||||
|
||||
# Outputs
|
||||
collection: list[ImageField] = Field(default=[], description="The output images")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["type", "collection"]}
|
||||
|
||||
|
||||
class RangeInvocation(BaseInvocation):
|
||||
"""Creates a range of numbers from start to stop with step"""
|
||||
|
||||
@ -92,3 +108,27 @@ class RandomRangeInvocation(BaseInvocation):
|
||||
return IntCollectionOutput(
|
||||
collection=list(rng.integers(low=self.low, high=self.high, size=self.size))
|
||||
)
|
||||
|
||||
|
||||
class ImageCollectionInvocation(BaseInvocation):
|
||||
"""Load a collection of images and provide it as output."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["image_collection"] = "image_collection"
|
||||
|
||||
# Inputs
|
||||
images: list[ImageField] = Field(
|
||||
default=[], description="The image collection to load"
|
||||
)
|
||||
# fmt: on
|
||||
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
|
||||
return ImageCollectionOutput(collection=self.images)
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"type_hints": {
|
||||
"images": "image_collection",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
@ -1,27 +1,28 @@
|
||||
from typing import Literal, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from contextlib import ExitStack
|
||||
import re
|
||||
from contextlib import ExitStack
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
from .model import ClipField
|
||||
import torch
|
||||
from compel import Compel
|
||||
from compel.prompt_parser import (Blend, Conjunction,
|
||||
CrossAttentionControlSubstitute,
|
||||
FlattenedPrompt, Fragment)
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ...backend.util.devices import torch_dtype
|
||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||
from ...backend.model_management.models import ModelNotFoundException
|
||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
|
||||
from compel import Compel
|
||||
from compel.prompt_parser import (
|
||||
Blend,
|
||||
CrossAttentionControlSubstitute,
|
||||
FlattenedPrompt,
|
||||
Fragment, Conjunction,
|
||||
)
|
||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||
from ...backend.util.devices import torch_dtype
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
from .model import ClipField
|
||||
|
||||
|
||||
class ConditioningField(BaseModel):
|
||||
conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")
|
||||
conditioning_name: Optional[str] = Field(
|
||||
default=None, description="The name of conditioning data")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["conditioning_name"]}
|
||||
|
||||
@ -51,83 +52,92 @@ class CompelInvocation(BaseInvocation):
|
||||
"title": "Prompt (Compel)",
|
||||
"tags": ["prompt", "compel"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
"model": "model"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**self.clip.tokenizer.dict(),
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
**self.clip.text_encoder.dict(),
|
||||
)
|
||||
with tokenizer_info as orig_tokenizer,\
|
||||
text_encoder_info as text_encoder:
|
||||
|
||||
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
def _lora_loader():
|
||||
for lora in self.clip.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}))
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
ti_list = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_list.append(
|
||||
context.services.model_manager.get_model(
|
||||
model_name=name,
|
||||
base_model=self.clip.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
).context.model
|
||||
)
|
||||
except Exception:
|
||||
#print(e)
|
||||
#import traceback
|
||||
#print(traceback.format_exc())
|
||||
print(f"Warn: trigger: \"{trigger}\" not found")
|
||||
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
with ModelPatcher.apply_lora_text_encoder(text_encoder, loras),\
|
||||
ModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager):
|
||||
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
textual_inversion_manager=ti_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=True, # TODO:
|
||||
ti_list = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_list.append(
|
||||
context.services.model_manager.get_model(
|
||||
model_name=name,
|
||||
base_model=self.clip.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
).context.model
|
||||
)
|
||||
except ModelNotFoundException:
|
||||
# print(e)
|
||||
#import traceback
|
||||
#print(traceback.format_exc())
|
||||
print(f"Warn: trigger: \"{trigger}\" not found")
|
||||
|
||||
conjunction = Compel.parse_prompt_string(self.prompt)
|
||||
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
|
||||
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
|
||||
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
|
||||
text_encoder_info as text_encoder:
|
||||
|
||||
if context.services.configuration.log_tokenization:
|
||||
log_tokenization_for_prompt_object(prompt, tokenizer)
|
||||
|
||||
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
|
||||
|
||||
# TODO: long prompt support
|
||||
#if not self.truncate_long_prompts:
|
||||
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
||||
)
|
||||
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
|
||||
# TODO: hacky but works ;D maybe rename latents somehow?
|
||||
context.services.latents.save(conditioning_name, (c, ec))
|
||||
|
||||
return CompelOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
textual_inversion_manager=ti_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=True, # TODO:
|
||||
)
|
||||
|
||||
conjunction = Compel.parse_prompt_string(self.prompt)
|
||||
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
|
||||
|
||||
if context.services.configuration.log_tokenization:
|
||||
log_tokenization_for_prompt_object(prompt, tokenizer)
|
||||
|
||||
c, options = compel.build_conditioning_tensor_for_prompt_object(
|
||||
prompt)
|
||||
|
||||
# TODO: long prompt support
|
||||
# if not self.truncate_long_prompts:
|
||||
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||
tokens_count_including_eos_bos=get_max_token_count(
|
||||
tokenizer, conjunction),
|
||||
cross_attention_control_args=options.get(
|
||||
"cross_attention_control", None),)
|
||||
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
|
||||
# TODO: hacky but works ;D maybe rename latents somehow?
|
||||
context.services.latents.save(conditioning_name, (c, ec))
|
||||
|
||||
return CompelOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_max_token_count(
|
||||
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False
|
||||
) -> int:
|
||||
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction],
|
||||
truncate_if_too_long=False) -> int:
|
||||
if type(prompt) is Blend:
|
||||
blend: Blend = prompt
|
||||
return max(
|
||||
@ -146,13 +156,13 @@ def get_max_token_count(
|
||||
)
|
||||
else:
|
||||
return len(
|
||||
get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)
|
||||
)
|
||||
get_tokens_for_prompt_object(
|
||||
tokenizer, prompt, truncate_if_too_long))
|
||||
|
||||
|
||||
def get_tokens_for_prompt_object(
|
||||
tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
|
||||
) -> [str]:
|
||||
) -> List[str]:
|
||||
if type(parsed_prompt) is Blend:
|
||||
raise ValueError(
|
||||
"Blend is not supported here - you need to get tokens for each of its .children"
|
||||
@ -181,7 +191,7 @@ def log_tokenization_for_conjunction(
|
||||
):
|
||||
display_label_prefix = display_label_prefix or ""
|
||||
for i, p in enumerate(c.prompts):
|
||||
if len(c.prompts)>1:
|
||||
if len(c.prompts) > 1:
|
||||
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
|
||||
else:
|
||||
this_display_label_prefix = display_label_prefix
|
||||
@ -236,7 +246,8 @@ def log_tokenization_for_prompt_object(
|
||||
)
|
||||
|
||||
|
||||
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
|
||||
def log_tokenization_for_text(
|
||||
text, tokenizer, display_label=None, truncate_if_too_long=False):
|
||||
"""shows how the prompt is tokenized
|
||||
# usually tokens have '</w>' to indicate end-of-word,
|
||||
# but for readability it has been replaced with ' '
|
||||
|
@ -4,18 +4,17 @@ from contextlib import ExitStack
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
import einops
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
import torch
|
||||
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from ...backend.image_util.seamless import configure_model_padding
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline,
|
||||
@ -24,7 +23,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
|
||||
PostprocessingSettings
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from ...backend.util.devices import torch_dtype
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
from .compel import ConditioningField
|
||||
@ -32,14 +31,17 @@ from .controlnet_image_processors import ControlField
|
||||
from .image import ImageOutput
|
||||
from .model import ModelInfo, UNetField, VaeField
|
||||
|
||||
|
||||
class LatentsField(BaseModel):
|
||||
"""A latents field used for passing latents between invocations"""
|
||||
|
||||
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
|
||||
latents_name: Optional[str] = Field(
|
||||
default=None, description="The name of the latents")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["latents_name"]}
|
||||
|
||||
|
||||
class LatentsOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output latents"""
|
||||
#fmt: off
|
||||
@ -53,11 +55,11 @@ class LatentsOutput(BaseInvocationOutput):
|
||||
|
||||
|
||||
def build_latents_output(latents_name: str, latents: torch.Tensor):
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=latents_name),
|
||||
width=latents.size()[3] * 8,
|
||||
height=latents.size()[2] * 8,
|
||||
)
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=latents_name),
|
||||
width=latents.size()[3] * 8,
|
||||
height=latents.size()[2] * 8,
|
||||
)
|
||||
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[
|
||||
@ -70,14 +72,17 @@ def get_scheduler(
|
||||
scheduler_info: ModelInfo,
|
||||
scheduler_name: str,
|
||||
) -> Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
|
||||
orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict())
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
|
||||
scheduler_name, SCHEDULER_MAP['ddim'])
|
||||
orig_scheduler_info = context.services.model_manager.get_model(
|
||||
**scheduler_info.dict())
|
||||
with orig_scheduler_info as orig_scheduler:
|
||||
scheduler_config = orig_scheduler.config
|
||||
|
||||
if "_backup" in scheduler_config:
|
||||
scheduler_config = scheduler_config["_backup"]
|
||||
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
|
||||
scheduler_config = {**scheduler_config, **
|
||||
scheduler_extra_config, "_backup": scheduler_config}
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
|
||||
# hack copied over from generate.py
|
||||
@ -124,18 +129,18 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
"ui": {
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
"control": "control",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number"
|
||||
"model": "model",
|
||||
"control": "control",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
def dispatch_progress(
|
||||
self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
|
||||
) -> None:
|
||||
self, context: InvocationContext, source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
@ -143,9 +148,12 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
def get_conditioning_data(self, context: InvocationContext, scheduler) -> ConditioningData:
|
||||
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
def get_conditioning_data(
|
||||
self, context: InvocationContext, scheduler) -> ConditioningData:
|
||||
c, extra_conditioning_info = context.services.latents.get(
|
||||
self.positive_conditioning.conditioning_name)
|
||||
uc, _ = context.services.latents.get(
|
||||
self.negative_conditioning.conditioning_name)
|
||||
|
||||
conditioning_data = ConditioningData(
|
||||
unconditioned_embeddings=uc,
|
||||
@ -153,10 +161,10 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
guidance_scale=self.cfg_scale,
|
||||
extra=extra_conditioning_info,
|
||||
postprocessing_settings=PostprocessingSettings(
|
||||
threshold=0.0,#threshold,
|
||||
warmup=0.2,#warmup,
|
||||
h_symmetry_time_pct=None,#h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=None#v_symmetry_time_pct,
|
||||
threshold=0.0, # threshold,
|
||||
warmup=0.2, # warmup,
|
||||
h_symmetry_time_pct=None, # h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=None # v_symmetry_time_pct,
|
||||
),
|
||||
)
|
||||
|
||||
@ -164,20 +172,21 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
scheduler,
|
||||
|
||||
# for ddim scheduler
|
||||
eta=0.0, #ddim_eta
|
||||
eta=0.0, # ddim_eta
|
||||
|
||||
# for ancestral and sde schedulers
|
||||
generator=torch.Generator(device=uc.device).manual_seed(0),
|
||||
)
|
||||
return conditioning_data
|
||||
|
||||
def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
|
||||
def create_pipeline(
|
||||
self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
|
||||
# TODO:
|
||||
#configure_model_padding(
|
||||
# configure_model_padding(
|
||||
# unet,
|
||||
# self.seamless,
|
||||
# self.seamless_axes,
|
||||
#)
|
||||
# )
|
||||
|
||||
class FakeVae:
|
||||
class FakeVaeConfig:
|
||||
@ -188,7 +197,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
self.config = FakeVae.FakeVaeConfig()
|
||||
|
||||
return StableDiffusionGeneratorPipeline(
|
||||
vae=FakeVae(), # TODO: oh...
|
||||
vae=FakeVae(), # TODO: oh...
|
||||
text_encoder=None,
|
||||
tokenizer=None,
|
||||
unet=unet,
|
||||
@ -202,7 +211,8 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
def prep_control_data(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device
|
||||
# really only need model for dtype and device
|
||||
model: StableDiffusionGeneratorPipeline,
|
||||
control_input: List[ControlField],
|
||||
latents_shape: List[int],
|
||||
do_classifier_free_guidance: bool = True,
|
||||
@ -238,15 +248,17 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
print("Using HF model subfolders")
|
||||
print(" control_name: ", control_name)
|
||||
print(" control_subfolder: ", control_subfolder)
|
||||
control_model = ControlNetModel.from_pretrained(control_name,
|
||||
subfolder=control_subfolder,
|
||||
torch_dtype=model.unet.dtype).to(model.device)
|
||||
control_model = ControlNetModel.from_pretrained(
|
||||
control_name, subfolder=control_subfolder,
|
||||
torch_dtype=model.unet.dtype).to(
|
||||
model.device)
|
||||
else:
|
||||
control_model = ControlNetModel.from_pretrained(control_info.control_model,
|
||||
torch_dtype=model.unet.dtype).to(model.device)
|
||||
control_model = ControlNetModel.from_pretrained(
|
||||
control_info.control_model, torch_dtype=model.unet.dtype).to(model.device)
|
||||
control_models.append(control_model)
|
||||
control_image_field = control_info.image
|
||||
input_image = context.services.images.get_pil_image(control_image_field.image_name)
|
||||
input_image = context.services.images.get_pil_image(
|
||||
control_image_field.image_name)
|
||||
# self.image.image_type, self.image.image_name
|
||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||
# and add in batch_size, num_images_per_prompt?
|
||||
@ -263,29 +275,40 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
dtype=control_model.dtype,
|
||||
control_mode=control_info.control_mode,
|
||||
)
|
||||
control_item = ControlNetData(model=control_model,
|
||||
image_tensor=control_image,
|
||||
weight=control_info.control_weight,
|
||||
begin_step_percent=control_info.begin_step_percent,
|
||||
end_step_percent=control_info.end_step_percent,
|
||||
control_mode=control_info.control_mode,
|
||||
)
|
||||
control_item = ControlNetData(
|
||||
model=control_model, image_tensor=control_image,
|
||||
weight=control_info.control_weight,
|
||||
begin_step_percent=control_info.begin_step_percent,
|
||||
end_step_percent=control_info.end_step_percent,
|
||||
control_mode=control_info.control_mode,)
|
||||
control_data.append(control_item)
|
||||
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
||||
return control_data
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, source_node_id, state)
|
||||
|
||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
|
||||
with unet_info as unet:
|
||||
def _lora_loader():
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}))
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict())
|
||||
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||
unet_info as unet:
|
||||
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
@ -296,8 +319,6 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
pipeline = self.create_pipeline(unet, scheduler)
|
||||
conditioning_data = self.get_conditioning_data(context, scheduler)
|
||||
|
||||
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
|
||||
|
||||
control_data = self.prep_control_data(
|
||||
model=pipeline, context=context, control_input=self.control,
|
||||
latents_shape=noise.shape,
|
||||
@ -305,16 +326,15 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
do_classifier_free_guidance=True,
|
||||
)
|
||||
|
||||
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
|
||||
# TODO: Verify the noise is the right size
|
||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
||||
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=control_data, # list[ControlNetData]
|
||||
callback=step_callback,
|
||||
)
|
||||
# TODO: Verify the noise is the right size
|
||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
||||
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=control_data, # list[ControlNetData]
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
@ -323,14 +343,18 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
context.services.latents.save(name, result_latents)
|
||||
return build_latents_output(latents_name=name, latents=result_latents)
|
||||
|
||||
|
||||
class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
"""Generates latents using latents as base image."""
|
||||
|
||||
type: Literal["l2l"] = "l2l"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
||||
strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use")
|
||||
latents: Optional[LatentsField] = Field(
|
||||
description="The latents to use as a base image")
|
||||
strength: float = Field(
|
||||
default=0.7, ge=0, le=1,
|
||||
description="The strength of the latents to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
@ -345,22 +369,31 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
},
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
latent = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, source_node_id, state)
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict(),
|
||||
)
|
||||
def _lora_loader():
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}))
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
with unet_info as unet:
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict())
|
||||
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||
unet_info as unet:
|
||||
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
@ -380,8 +413,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
||||
latent, device=unet.device, dtype=latent.dtype
|
||||
)
|
||||
latent, device=unet.device, dtype=latent.dtype)
|
||||
|
||||
timesteps, _ = pipeline.get_img2img_timesteps(
|
||||
self.steps,
|
||||
@ -389,18 +421,15 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
device=unet.device,
|
||||
)
|
||||
|
||||
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
|
||||
|
||||
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
|
||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
||||
latents=initial_latents,
|
||||
timesteps=timesteps,
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=control_data, # list[ControlNetData]
|
||||
callback=step_callback
|
||||
)
|
||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
||||
latents=initial_latents,
|
||||
timesteps=timesteps,
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=control_data, # list[ControlNetData]
|
||||
callback=step_callback
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
@ -417,9 +446,12 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
type: Literal["l2i"] = "l2i"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
|
||||
latents: Optional[LatentsField] = Field(
|
||||
description="The latents to generate an image from")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
|
||||
tiled: bool = Field(
|
||||
default=False,
|
||||
description="Decode latents by overlaping tiles(less memory consumption)")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
@ -450,7 +482,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
# copied from diffusers pipeline
|
||||
latents = latents / vae.config.scaling_factor
|
||||
image = vae.decode(latents, return_dict=False)[0]
|
||||
image = (image / 2 + 0.5).clamp(0, 1) # denormalize
|
||||
image = (image / 2 + 0.5).clamp(0, 1) # denormalize
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
@ -473,9 +505,9 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
LATENTS_INTERPOLATION_MODE = Literal[
|
||||
"nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
|
||||
]
|
||||
|
||||
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear",
|
||||
"bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
||||
|
||||
|
||||
class ResizeLatentsInvocation(BaseInvocation):
|
||||
@ -484,21 +516,25 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
type: Literal["lresize"] = "lresize"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to resize")
|
||||
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
||||
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||
latents: Optional[LatentsField] = Field(
|
||||
description="The latents to resize")
|
||||
width: int = Field(
|
||||
ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||
height: int = Field(
|
||||
ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||
mode: LATENTS_INTERPOLATION_MODE = Field(
|
||||
default="bilinear", description="The interpolation mode")
|
||||
antialias: bool = Field(
|
||||
default=False,
|
||||
description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
resized_latents = torch.nn.functional.interpolate(
|
||||
latents,
|
||||
size=(self.height // 8, self.width // 8),
|
||||
mode=self.mode,
|
||||
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
|
||||
)
|
||||
latents, size=(self.height // 8, self.width // 8),
|
||||
mode=self.mode, antialias=self.antialias
|
||||
if self.mode in ["bilinear", "bicubic"] else False,)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
@ -515,21 +551,24 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
type: Literal["lscale"] = "lscale"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to scale")
|
||||
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
|
||||
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
||||
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||
latents: Optional[LatentsField] = Field(
|
||||
description="The latents to scale")
|
||||
scale_factor: float = Field(
|
||||
gt=0, description="The factor by which to scale the latents")
|
||||
mode: LATENTS_INTERPOLATION_MODE = Field(
|
||||
default="bilinear", description="The interpolation mode")
|
||||
antialias: bool = Field(
|
||||
default=False,
|
||||
description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
# resizing
|
||||
resized_latents = torch.nn.functional.interpolate(
|
||||
latents,
|
||||
scale_factor=self.scale_factor,
|
||||
mode=self.mode,
|
||||
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
|
||||
)
|
||||
latents, scale_factor=self.scale_factor, mode=self.mode,
|
||||
antialias=self.antialias
|
||||
if self.mode in ["bilinear", "bicubic"] else False,)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
@ -548,7 +587,9 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(description="The image to encode")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)")
|
||||
tiled: bool = Field(
|
||||
default=False,
|
||||
description="Encode latents by overlaping tiles(less memory consumption)")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
|
@ -1,31 +1,38 @@
|
||||
from typing import Literal, Optional, Union, List
|
||||
from pydantic import BaseModel, Field
|
||||
import copy
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
model_name: str = Field(description="Info to load submodel")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
model_type: ModelType = Field(description="Info to load submodel")
|
||||
submodel: Optional[SubModelType] = Field(description="Info to load submodel")
|
||||
submodel: Optional[SubModelType] = Field(
|
||||
default=None, description="Info to load submodel"
|
||||
)
|
||||
|
||||
|
||||
class LoraInfo(ModelInfo):
|
||||
weight: float = Field(description="Lora's weight which to use when apply to model")
|
||||
|
||||
|
||||
class UNetField(BaseModel):
|
||||
unet: ModelInfo = Field(description="Info to load unet submodel")
|
||||
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
|
||||
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
||||
|
||||
|
||||
class ClipField(BaseModel):
|
||||
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
|
||||
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
|
||||
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
||||
|
||||
|
||||
class VaeField(BaseModel):
|
||||
# TODO: better naming?
|
||||
vae: ModelInfo = Field(description="Info to load vae submodel")
|
||||
@ -34,43 +41,48 @@ class VaeField(BaseModel):
|
||||
class ModelLoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
|
||||
#fmt: off
|
||||
# fmt: off
|
||||
type: Literal["model_loader_output"] = "model_loader_output"
|
||||
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
|
||||
class PipelineModelField(BaseModel):
|
||||
"""Pipeline model field"""
|
||||
class MainModelField(BaseModel):
|
||||
"""Main model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
|
||||
class PipelineModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a pipeline model, outputting its submodels."""
|
||||
class LoRAModelField(BaseModel):
|
||||
"""LoRA model field"""
|
||||
|
||||
type: Literal["pipeline_model_loader"] = "pipeline_model_loader"
|
||||
model_name: str = Field(description="Name of the LoRA model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
model: PipelineModelField = Field(description="The model to load")
|
||||
|
||||
class MainModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
||||
type: Literal["main_model_loader"] = "main_model_loader"
|
||||
|
||||
model: MainModelField = Field(description="The model to load")
|
||||
# TODO: precision?
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Model Loader",
|
||||
"tags": ["model", "loader"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
}
|
||||
"type_hints": {"model": "model"},
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
model_type = ModelType.Main
|
||||
@ -112,7 +124,6 @@ class PipelineModelLoaderInvocation(BaseInvocation):
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
return ModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
@ -151,47 +162,66 @@ class PipelineModelLoaderInvocation(BaseInvocation):
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Vae,
|
||||
),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class LoraLoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
|
||||
#fmt: off
|
||||
# fmt: off
|
||||
type: Literal["lora_loader_output"] = "lora_loader_output"
|
||||
|
||||
unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
|
||||
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
|
||||
class LoraLoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
type: Literal["lora_loader"] = "lora_loader"
|
||||
|
||||
lora_name: str = Field(description="Lora model name")
|
||||
lora: Union[LoRAModelField, None] = Field(
|
||||
default=None, description="Lora model name"
|
||||
)
|
||||
weight: float = Field(default=0.75, description="With what weight to apply lora")
|
||||
|
||||
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
|
||||
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Lora Loader",
|
||||
"tags": ["lora", "loader"],
|
||||
"type_hints": {"lora": "lora_model"},
|
||||
},
|
||||
}
|
||||
|
||||
# TODO: ui rewrite
|
||||
base_model = BaseModelType.StableDiffusion1
|
||||
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
||||
if self.lora is None:
|
||||
raise Exception("No LoRA provided")
|
||||
|
||||
base_model = self.lora.base_model
|
||||
lora_name = self.lora.model_name
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
base_model=base_model,
|
||||
model_name=self.lora_name,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
):
|
||||
raise Exception(f"Unkown lora name: {self.lora_name}!")
|
||||
raise Exception(f"Unkown lora name: {lora_name}!")
|
||||
|
||||
if self.unet is not None and any(lora.model_name == self.lora_name for lora in self.unet.loras):
|
||||
raise Exception(f"Lora \"{self.lora_name}\" already applied to unet")
|
||||
if self.unet is not None and any(
|
||||
lora.model_name == lora_name for lora in self.unet.loras
|
||||
):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to unet')
|
||||
|
||||
if self.clip is not None and any(lora.model_name == self.lora_name for lora in self.clip.loras):
|
||||
raise Exception(f"Lora \"{self.lora_name}\" already applied to clip")
|
||||
if self.clip is not None and any(
|
||||
lora.model_name == lora_name for lora in self.clip.loras
|
||||
):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to clip')
|
||||
|
||||
output = LoraLoaderOutput()
|
||||
|
||||
@ -200,7 +230,7 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
output.unet.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=self.lora_name,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
@ -212,7 +242,7 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
output.clip.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=self.lora_name,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
@ -221,3 +251,58 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class VAEModelField(BaseModel):
|
||||
"""Vae model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
|
||||
class VaeLoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["vae_loader_output"] = "vae_loader_output"
|
||||
|
||||
vae: VaeField = Field(default=None, description="Vae model")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class VaeLoaderInvocation(BaseInvocation):
|
||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||
|
||||
type: Literal["vae_loader"] = "vae_loader"
|
||||
|
||||
vae_model: VAEModelField = Field(description="The VAE to load")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "VAE Loader",
|
||||
"tags": ["vae", "loader"],
|
||||
"type_hints": {"vae_model": "vae_model"},
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
|
||||
base_model = self.vae_model.base_model
|
||||
model_name = self.vae_model.model_name
|
||||
model_type = ModelType.Vae
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
base_model=base_model,
|
||||
model_name=model_name,
|
||||
model_type=model_type,
|
||||
):
|
||||
raise Exception(f"Unkown vae name: {model_name}!")
|
||||
return VaeLoaderOutput(
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
@ -367,7 +367,8 @@ setting environment variables INVOKEAI_<setting>.
|
||||
|
||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
||||
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
||||
max_loaded_models : int = Field(default=3, gt=0, description="Maximum number of models to keep in memory for rapid switching", category='Memory/Performance')
|
||||
max_loaded_models : int = Field(default=3, gt=0, description="(DEPRECATED: use max_cache_size) Maximum number of models to keep in memory for rapid switching", category='Memory/Performance')
|
||||
max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
|
||||
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance')
|
||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
|
||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
||||
|
@ -7,7 +7,7 @@ if TYPE_CHECKING:
|
||||
from invokeai.app.services.board_images import BoardImagesServiceABC
|
||||
from invokeai.app.services.boards import BoardServiceABC
|
||||
from invokeai.app.services.images import ImageServiceABC
|
||||
from invokeai.backend import ModelManager
|
||||
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.latent_storage import LatentsStorageBase
|
||||
from invokeai.app.services.restoration_services import RestorationServices
|
||||
@ -22,46 +22,47 @@ class InvocationServices:
|
||||
"""Services that can be used by invocations"""
|
||||
|
||||
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
||||
events: "EventServiceBase"
|
||||
latents: "LatentsStorageBase"
|
||||
queue: "InvocationQueueABC"
|
||||
model_manager: "ModelManager"
|
||||
restoration: "RestorationServices"
|
||||
configuration: "InvokeAISettings"
|
||||
images: "ImageServiceABC"
|
||||
boards: "BoardServiceABC"
|
||||
board_images: "BoardImagesServiceABC"
|
||||
graph_library: "ItemStorageABC"["LibraryGraph"]
|
||||
boards: "BoardServiceABC"
|
||||
configuration: "InvokeAISettings"
|
||||
events: "EventServiceBase"
|
||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
|
||||
graph_library: "ItemStorageABC"["LibraryGraph"]
|
||||
images: "ImageServiceABC"
|
||||
latents: "LatentsStorageBase"
|
||||
logger: "Logger"
|
||||
model_manager: "ModelManagerServiceBase"
|
||||
processor: "InvocationProcessorABC"
|
||||
queue: "InvocationQueueABC"
|
||||
restoration: "RestorationServices"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_manager: "ModelManager",
|
||||
events: "EventServiceBase",
|
||||
logger: "Logger",
|
||||
latents: "LatentsStorageBase",
|
||||
images: "ImageServiceABC",
|
||||
boards: "BoardServiceABC",
|
||||
board_images: "BoardImagesServiceABC",
|
||||
queue: "InvocationQueueABC",
|
||||
graph_library: "ItemStorageABC"["LibraryGraph"],
|
||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
|
||||
processor: "InvocationProcessorABC",
|
||||
restoration: "RestorationServices",
|
||||
boards: "BoardServiceABC",
|
||||
configuration: "InvokeAISettings",
|
||||
events: "EventServiceBase",
|
||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
|
||||
graph_library: "ItemStorageABC"["LibraryGraph"],
|
||||
images: "ImageServiceABC",
|
||||
latents: "LatentsStorageBase",
|
||||
logger: "Logger",
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
processor: "InvocationProcessorABC",
|
||||
queue: "InvocationQueueABC",
|
||||
restoration: "RestorationServices",
|
||||
):
|
||||
self.model_manager = model_manager
|
||||
self.events = events
|
||||
self.logger = logger
|
||||
self.latents = latents
|
||||
self.images = images
|
||||
self.boards = boards
|
||||
self.board_images = board_images
|
||||
self.queue = queue
|
||||
self.graph_library = graph_library
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
self.processor = processor
|
||||
self.restoration = restoration
|
||||
self.configuration = configuration
|
||||
self.boards = boards
|
||||
self.boards = boards
|
||||
self.configuration = configuration
|
||||
self.events = events
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
self.graph_library = graph_library
|
||||
self.images = images
|
||||
self.latents = latents
|
||||
self.logger = logger
|
||||
self.model_manager = model_manager
|
||||
self.processor = processor
|
||||
self.queue = queue
|
||||
self.restoration = restoration
|
||||
|
@ -135,6 +135,29 @@ class ModelManagerServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def heuristic_import(self,
|
||||
items_to_import: Set[str],
|
||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||
)->Dict[str, AddModelResult]:
|
||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||
|
||||
The prediction type helper is necessary to distinguish between
|
||||
models based on Stable Diffusion 2 Base (requiring
|
||||
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
|
||||
(requiring SchedulerPredictionType.VPrediction). It is
|
||||
generally impossible to do this programmatically, so the
|
||||
prediction_type_helper usually asks the user to choose.
|
||||
|
||||
The result is a set of successfully installed models. Each element
|
||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||
that model.
|
||||
'''
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def commit(self, conf_file: Path = None) -> None:
|
||||
"""
|
||||
@ -183,6 +206,8 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
if hasattr(config,'max_cache_size') \
|
||||
else config.max_loaded_models * 2.5
|
||||
|
||||
logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")
|
||||
|
||||
sequential_offload = config.sequential_guidance
|
||||
|
||||
self.mgr = ModelManager(
|
||||
@ -361,3 +386,24 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
def logger(self):
|
||||
return self.mgr.logger
|
||||
|
||||
def heuristic_import(self,
|
||||
items_to_import: Set[str],
|
||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||
)->Dict[str, AddModelResult]:
|
||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||
|
||||
The prediction type helper is necessary to distinguish between
|
||||
models based on Stable Diffusion 2 Base (requiring
|
||||
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
|
||||
(requiring SchedulerPredictionType.VPrediction). It is
|
||||
generally impossible to do this programmatically, so the
|
||||
prediction_type_helper usually asks the user to choose.
|
||||
|
||||
The result is a set of successfully installed models. Each element
|
||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||
that model.
|
||||
'''
|
||||
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
||||
|
@ -430,13 +430,13 @@ to allow InvokeAI to download restricted styles & subjects from the "Concept Lib
|
||||
max_height=len(PRECISION_CHOICES) + 1,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.max_loaded_models = self.add_widget_intelligent(
|
||||
self.max_cache_size = self.add_widget_intelligent(
|
||||
IntTitleSlider,
|
||||
name="Number of models to cache in CPU memory (each will use 2-4 GB!)",
|
||||
value=old_opts.max_loaded_models,
|
||||
out_of=10,
|
||||
lowest=1,
|
||||
begin_entry_at=4,
|
||||
name="Size of the RAM cache used for fast model switching (GB)",
|
||||
value=old_opts.max_cache_size,
|
||||
out_of=20,
|
||||
lowest=3,
|
||||
begin_entry_at=6,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
@ -539,7 +539,7 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
|
||||
"outdir",
|
||||
"nsfw_checker",
|
||||
"free_gpu_mem",
|
||||
"max_loaded_models",
|
||||
"max_cache_size",
|
||||
"xformers_enabled",
|
||||
"always_use_cpu",
|
||||
]:
|
||||
@ -555,9 +555,6 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
|
||||
new_opts.license_acceptance = self.license_acceptance.value
|
||||
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
|
||||
|
||||
# widget library workaround to make max_loaded_models an int rather than a float
|
||||
new_opts.max_loaded_models = int(new_opts.max_loaded_models)
|
||||
|
||||
return new_opts
|
||||
|
||||
|
||||
|
@ -4,6 +4,8 @@ import argparse
|
||||
import shlex
|
||||
from argparse import ArgumentParser
|
||||
|
||||
# note that this includes both old sampler names and new scheduler names
|
||||
# in order to be able to parse both 2.0 and 3.0-pre-nodes versions of invokeai.init
|
||||
SAMPLER_CHOICES = [
|
||||
"ddim",
|
||||
"ddpm",
|
||||
@ -27,6 +29,15 @@ SAMPLER_CHOICES = [
|
||||
"dpmpp_sde",
|
||||
"dpmpp_sde_k",
|
||||
"unipc",
|
||||
"k_dpm_2_a",
|
||||
"k_dpm_2",
|
||||
"k_dpmpp_2_a",
|
||||
"k_dpmpp_2",
|
||||
"k_euler_a",
|
||||
"k_euler",
|
||||
"k_heun",
|
||||
"k_lms",
|
||||
"plms",
|
||||
]
|
||||
|
||||
PRECISION_CHOICES = [
|
||||
|
@ -76,6 +76,10 @@ class MigrateTo3(object):
|
||||
Create a unique name for a model for use within models.yaml.
|
||||
'''
|
||||
done = False
|
||||
|
||||
# some model names have slashes in them, which really screws things up
|
||||
name = name.replace('/','_')
|
||||
|
||||
key = ModelManager.create_key(name,info.base_type,info.model_type)
|
||||
unique_name = key
|
||||
counter = 1
|
||||
@ -219,11 +223,12 @@ class MigrateTo3(object):
|
||||
repo_id = 'openai/clip-vit-large-patch14'
|
||||
self._migrate_pretrained(CLIPTokenizer,
|
||||
repo_id= repo_id,
|
||||
dest= target_dir / 'clip-vit-large-patch14' / 'tokenizer',
|
||||
dest= target_dir / 'clip-vit-large-patch14',
|
||||
**kwargs)
|
||||
self._migrate_pretrained(CLIPTextModel,
|
||||
repo_id = repo_id,
|
||||
dest = target_dir / 'clip-vit-large-patch14' / 'text_encoder',
|
||||
dest = target_dir / 'clip-vit-large-patch14',
|
||||
force = True,
|
||||
**kwargs)
|
||||
|
||||
# sd-2
|
||||
@ -287,21 +292,21 @@ class MigrateTo3(object):
|
||||
def _model_probe_to_path(self, info: ModelProbeInfo)->Path:
|
||||
return Path(self.dest_models, info.base_type.value, info.model_type.value)
|
||||
|
||||
def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, **kwargs):
|
||||
if dest.exists():
|
||||
def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force:bool=False, **kwargs):
|
||||
if dest.exists() and not force:
|
||||
logger.info(f'Skipping existing {dest}')
|
||||
return
|
||||
model = model_class.from_pretrained(repo_id, **kwargs)
|
||||
self._save_pretrained(model, dest)
|
||||
self._save_pretrained(model, dest, overwrite=force)
|
||||
|
||||
def _save_pretrained(self, model, dest: Path):
|
||||
if dest.exists():
|
||||
logger.info(f'Skipping existing {dest}')
|
||||
return
|
||||
def _save_pretrained(self, model, dest: Path, overwrite: bool=False):
|
||||
model_name = dest.name
|
||||
download_path = dest.with_name(f'{model_name}.downloading')
|
||||
model.save_pretrained(download_path, safe_serialization=True)
|
||||
download_path.replace(dest)
|
||||
if overwrite:
|
||||
model.save_pretrained(dest, safe_serialization=True)
|
||||
else:
|
||||
download_path = dest.with_name(f'{model_name}.downloading')
|
||||
model.save_pretrained(download_path, safe_serialization=True)
|
||||
download_path.replace(dest)
|
||||
|
||||
def _download_vae(self, repo_id: str, subfolder:str=None)->Path:
|
||||
vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / 'models/hub', subfolder=subfolder)
|
||||
@ -569,8 +574,10 @@ script, which will perform a full upgrade in place."""
|
||||
|
||||
dest_directory = args.dest_directory
|
||||
assert dest_directory.is_dir(), f"{dest_directory} is not a valid directory"
|
||||
assert (dest_directory / 'models').is_dir(), f"{dest_directory} does not contain a 'models' subdirectory"
|
||||
assert (dest_directory / 'invokeai.yaml').exists(), f"{dest_directory} does not contain an InvokeAI init file."
|
||||
|
||||
# TODO: revisit
|
||||
# assert (dest_directory / 'models').is_dir(), f"{dest_directory} does not contain a 'models' subdirectory"
|
||||
# assert (dest_directory / 'invokeai.yaml').exists(), f"{dest_directory} does not contain an InvokeAI init file."
|
||||
|
||||
do_migrate(root_directory,dest_directory)
|
||||
|
||||
|
@ -18,7 +18,7 @@ from tqdm import tqdm
|
||||
import invokeai.configs as configs
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType
|
||||
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
|
||||
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
|
||||
from invokeai.backend.util import download_with_resume
|
||||
from ..util.logging import InvokeAILogger
|
||||
@ -166,17 +166,22 @@ class ModelInstall(object):
|
||||
# add requested models
|
||||
for path in selections.install_models:
|
||||
logger.info(f'Installing {path} [{job}/{jobs}]')
|
||||
self.heuristic_install(path)
|
||||
self.heuristic_import(path)
|
||||
job += 1
|
||||
|
||||
self.mgr.commit()
|
||||
|
||||
def heuristic_install(self,
|
||||
def heuristic_import(self,
|
||||
model_path_id_or_url: Union[str,Path],
|
||||
models_installed: Set[Path]=None)->Set[Path]:
|
||||
models_installed: Set[Path]=None)->Dict[str, AddModelResult]:
|
||||
'''
|
||||
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
|
||||
:param models_installed: Set of installed models, used for recursive invocation
|
||||
Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
|
||||
'''
|
||||
|
||||
if not models_installed:
|
||||
models_installed = set()
|
||||
models_installed = dict()
|
||||
|
||||
# A little hack to allow nested routines to retrieve info on the requested ID
|
||||
self.current_id = model_path_id_or_url
|
||||
@ -185,24 +190,27 @@ class ModelInstall(object):
|
||||
try:
|
||||
# checkpoint file, or similar
|
||||
if path.is_file():
|
||||
models_installed.add(self._install_path(path))
|
||||
models_installed.update(self._install_path(path))
|
||||
|
||||
# folders style or similar
|
||||
elif path.is_dir() and any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
|
||||
models_installed.add(self._install_path(path))
|
||||
elif path.is_dir() and any([(path/x).exists() for x in \
|
||||
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
|
||||
]
|
||||
):
|
||||
models_installed.update(self._install_path(path))
|
||||
|
||||
# recursive scan
|
||||
elif path.is_dir():
|
||||
for child in path.iterdir():
|
||||
self.heuristic_install(child, models_installed=models_installed)
|
||||
self.heuristic_import(child, models_installed=models_installed)
|
||||
|
||||
# huggingface repo
|
||||
elif len(str(path).split('/')) == 2:
|
||||
models_installed.add(self._install_repo(str(path)))
|
||||
models_installed.update(self._install_repo(str(path)))
|
||||
|
||||
# a URL
|
||||
elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")):
|
||||
models_installed.add(self._install_url(model_path_id_or_url))
|
||||
models_installed.update(self._install_url(model_path_id_or_url))
|
||||
|
||||
else:
|
||||
logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
|
||||
@ -214,24 +222,25 @@ class ModelInstall(object):
|
||||
|
||||
# install a model from a local path. The optional info parameter is there to prevent
|
||||
# the model from being probed twice in the event that it has already been probed.
|
||||
def _install_path(self, path: Path, info: ModelProbeInfo=None)->Path:
|
||||
def _install_path(self, path: Path, info: ModelProbeInfo=None)->Dict[str, AddModelResult]:
|
||||
try:
|
||||
# logger.debug(f'Probing {path}')
|
||||
model_result = None
|
||||
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
|
||||
model_name = path.stem if info.format=='checkpoint' else path.name
|
||||
model_name = path.stem if path.is_file() else path.name
|
||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||
raise ValueError(f'A model named "{model_name}" is already installed.')
|
||||
attributes = self._make_attributes(path,info)
|
||||
self.mgr.add_model(model_name = model_name,
|
||||
base_model = info.base_type,
|
||||
model_type = info.model_type,
|
||||
model_attributes = attributes,
|
||||
)
|
||||
model_result = self.mgr.add_model(model_name = model_name,
|
||||
base_model = info.base_type,
|
||||
model_type = info.model_type,
|
||||
model_attributes = attributes,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f'{str(e)} Skipping registration.')
|
||||
return path
|
||||
return {}
|
||||
return {str(path): model_result}
|
||||
|
||||
def _install_url(self, url: str)->Path:
|
||||
def _install_url(self, url: str)->dict:
|
||||
# copy to a staging area, probe, import and delete
|
||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||
location = download_with_resume(url,Path(staging))
|
||||
@ -244,7 +253,7 @@ class ModelInstall(object):
|
||||
# staged version will be garbage-collected at this time
|
||||
return self._install_path(Path(models_path), info)
|
||||
|
||||
def _install_repo(self, repo_id: str)->Path:
|
||||
def _install_repo(self, repo_id: str)->dict:
|
||||
hinfo = HfApi().model_info(repo_id)
|
||||
|
||||
# we try to figure out how to download this most economically
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend.model_management
|
||||
"""
|
||||
from .model_manager import ModelManager, ModelInfo
|
||||
from .model_manager import ModelManager, ModelInfo, AddModelResult
|
||||
from .model_cache import ModelCache
|
||||
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
|
||||
|
||||
|
@ -29,7 +29,7 @@ import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
from .model_manager import ModelManager
|
||||
from .model_cache import ModelCache
|
||||
from picklescan.scanner import scan_file_path
|
||||
from .models import BaseModelType, ModelVariantType
|
||||
|
||||
try:
|
||||
@ -1014,7 +1014,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
checkpoint = load_file(checkpoint_path)
|
||||
else:
|
||||
if scan_needed:
|
||||
ModelCache.scan_model(checkpoint_path, checkpoint_path)
|
||||
# scan model
|
||||
scan_result = scan_file_path(checkpoint_path)
|
||||
if scan_result.infected_files != 0:
|
||||
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
|
||||
checkpoint = torch.load(checkpoint_path)
|
||||
|
||||
# sometimes there is a state_dict key and sometimes not
|
||||
|
@ -1,18 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Dict, Tuple, Any
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Tuple, Union, List
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
from transformers import CLIPTextModel
|
||||
|
||||
from compel.embeddings_provider import BaseTextualInversionManager
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
from safetensors.torch import load_file
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
class LoRALayerBase:
|
||||
#rank: Optional[int]
|
||||
@ -124,8 +121,8 @@ class LoRALayer(LoRALayerBase):
|
||||
|
||||
def get_weight(self):
|
||||
if self.mid is not None:
|
||||
up = self.up.reshape(up.shape[0], up.shape[1])
|
||||
down = self.down.reshape(up.shape[0], up.shape[1])
|
||||
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
||||
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
||||
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
|
||||
else:
|
||||
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
||||
@ -411,7 +408,7 @@ class LoRAModel: #(torch.nn.Module):
|
||||
else:
|
||||
# TODO: diff/ia3/... format
|
||||
print(
|
||||
f">> Encountered unknown lora layer module in {self.name}: {layer_key}"
|
||||
f">> Encountered unknown lora layer module in {model.name}: {layer_key}"
|
||||
)
|
||||
return
|
||||
|
||||
@ -539,9 +536,10 @@ class ModelPatcher:
|
||||
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
|
||||
|
||||
# enable autocast to calc fp16 loras on cpu
|
||||
with torch.autocast(device_type="cpu"):
|
||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||
layer_weight = layer.get_weight() * lora_weight * layer_scale
|
||||
#with torch.autocast(device_type="cpu"):
|
||||
layer.to(dtype=torch.float32)
|
||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||
layer_weight = layer.get_weight() * lora_weight * layer_scale
|
||||
|
||||
if module.weight.shape != layer_weight.shape:
|
||||
# TODO: debug on lycoris
|
||||
@ -655,6 +653,9 @@ class TextualInversionModel:
|
||||
else:
|
||||
result.embedding = next(iter(state_dict.values()))
|
||||
|
||||
if len(result.embedding.shape) == 1:
|
||||
result.embedding = result.embedding.unsqueeze(0)
|
||||
|
||||
if not isinstance(result.embedding, torch.Tensor):
|
||||
raise ValueError(f"Invalid embeddings file: {file_path.name}")
|
||||
|
||||
|
@ -8,7 +8,7 @@ The cache returns context manager generators designed to load the
|
||||
model into the GPU within the context, and unload outside the
|
||||
context. Use like this:
|
||||
|
||||
cache = ModelCache(max_models_cached=6)
|
||||
cache = ModelCache(max_cache_size=7.5)
|
||||
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
|
||||
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
|
||||
do_something_in_GPU(SD1,SD2)
|
||||
@ -91,7 +91,7 @@ class ModelCache(object):
|
||||
logger: types.ModuleType = logger
|
||||
):
|
||||
'''
|
||||
:param max_models: Maximum number of models to cache in CPU RAM [4]
|
||||
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
|
||||
:param execution_device: Torch device to load active model into [torch.device('cuda')]
|
||||
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
||||
:param precision: Precision for loaded models [torch.float16]
|
||||
@ -100,8 +100,6 @@ class ModelCache(object):
|
||||
:param sha_chunksize: Chunksize to use when calculating sha256 model hash
|
||||
'''
|
||||
#max_cache_size = 9999
|
||||
execution_device = torch.device('cuda')
|
||||
|
||||
self.model_infos: Dict[str, ModelBase] = dict()
|
||||
self.lazy_offloading = lazy_offloading
|
||||
#self.sequential_offload: bool=sequential_offload
|
||||
@ -128,16 +126,6 @@ class ModelCache(object):
|
||||
key += f":{submodel_type}"
|
||||
return key
|
||||
|
||||
#def get_model(
|
||||
# self,
|
||||
# repo_id_or_path: Union[str, Path],
|
||||
# model_type: ModelType = ModelType.Diffusers,
|
||||
# subfolder: Path = None,
|
||||
# submodel: ModelType = None,
|
||||
# revision: str = None,
|
||||
# attach_model_part: Tuple[ModelType, str] = (None, None),
|
||||
# gpu_load: bool = True,
|
||||
#) -> ModelLocker: # ?? what does it return
|
||||
def _get_model_info(
|
||||
self,
|
||||
model_path: str,
|
||||
|
@ -233,14 +233,14 @@ import hashlib
|
||||
import textwrap
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Tuple, Union, Set, Callable, types
|
||||
from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types
|
||||
from shutil import rmtree
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
@ -249,7 +249,7 @@ from .model_cache import ModelCache, ModelLocker
|
||||
from .models import (
|
||||
BaseModelType, ModelType, SubModelType,
|
||||
ModelError, SchedulerPredictionType, MODEL_CLASSES,
|
||||
ModelConfigBase,
|
||||
ModelConfigBase, ModelNotFoundException,
|
||||
)
|
||||
|
||||
# We are only starting to number the config file with release 3.
|
||||
@ -278,8 +278,13 @@ class InvalidModelError(Exception):
|
||||
"Raised when an invalid model is requested"
|
||||
pass
|
||||
|
||||
MAX_CACHE_SIZE = 6.0 # GB
|
||||
class AddModelResult(BaseModel):
|
||||
name: str = Field(description="The name of the model after import")
|
||||
model_type: ModelType = Field(description="The type of model")
|
||||
base_model: BaseModelType = Field(description="The base model")
|
||||
config: ModelConfigBase = Field(description="The configuration of the model")
|
||||
|
||||
MAX_CACHE_SIZE = 6.0 # GB
|
||||
|
||||
class ConfigMeta(BaseModel):
|
||||
version: str
|
||||
@ -404,7 +409,7 @@ class ModelManager(object):
|
||||
if model_key not in self.models:
|
||||
self.scan_models_directory(base_model=base_model, model_type=model_type)
|
||||
if model_key not in self.models:
|
||||
raise Exception(f"Model not found - {model_key}")
|
||||
raise ModelNotFoundException(f"Model not found - {model_key}")
|
||||
|
||||
model_config = self.models[model_key]
|
||||
model_path = self.app_config.root_path / model_config.path
|
||||
@ -416,14 +421,14 @@ class ModelManager(object):
|
||||
|
||||
else:
|
||||
self.models.pop(model_key, None)
|
||||
raise Exception(f"Model not found - {model_key}")
|
||||
raise ModelNotFoundException(f"Model not found - {model_key}")
|
||||
|
||||
# vae/movq override
|
||||
# TODO:
|
||||
if submodel_type is not None and hasattr(model_config, submodel_type):
|
||||
override_path = getattr(model_config, submodel_type)
|
||||
if override_path:
|
||||
model_path = override_path
|
||||
model_path = self.app_config.root_path / override_path
|
||||
model_type = submodel_type
|
||||
submodel_type = None
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
@ -431,6 +436,7 @@ class ModelManager(object):
|
||||
# TODO: path
|
||||
# TODO: is it accurate to use path as id
|
||||
dst_convert_path = self._get_model_cache_path(model_path)
|
||||
|
||||
model_path = model_class.convert_if_required(
|
||||
base_model=base_model,
|
||||
model_path=str(model_path), # TODO: refactor str/Path types logic
|
||||
@ -570,13 +576,16 @@ class ModelManager(object):
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False,
|
||||
) -> None:
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
On a successful update, the config will be changed in memory and the
|
||||
method will return True. Will fail with an assertion error if provided
|
||||
attributes are incorrect or the model name is missing.
|
||||
|
||||
The returned dict has the same format as the dict returned by
|
||||
model_info().
|
||||
"""
|
||||
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
@ -600,12 +609,18 @@ class ModelManager(object):
|
||||
old_model_cache.unlink()
|
||||
|
||||
# remove in-memory cache
|
||||
# note: it not garantie to release memory(model can has other references)
|
||||
# note: it not guaranteed to release memory(model can has other references)
|
||||
cache_ids = self.cache_keys.pop(model_key, [])
|
||||
for cache_id in cache_ids:
|
||||
self.cache.uncache_model(cache_id)
|
||||
|
||||
self.models[model_key] = model_config
|
||||
return AddModelResult(
|
||||
name = model_name,
|
||||
model_type = model_type,
|
||||
base_model = base_model,
|
||||
config = model_config,
|
||||
)
|
||||
|
||||
def search_models(self, search_folder):
|
||||
self.logger.info(f"Finding Models In: {search_folder}")
|
||||
@ -716,19 +731,19 @@ class ModelManager(object):
|
||||
|
||||
if model_path.is_relative_to(self.app_config.root_path):
|
||||
model_path = model_path.relative_to(self.app_config.root_path)
|
||||
try:
|
||||
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
|
||||
self.models[model_key] = model_config
|
||||
new_models_found = True
|
||||
except NotImplementedError as e:
|
||||
self.logger.warning(e)
|
||||
try:
|
||||
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
|
||||
self.models[model_key] = model_config
|
||||
new_models_found = True
|
||||
except NotImplementedError as e:
|
||||
self.logger.warning(e)
|
||||
|
||||
imported_models = self.autoimport()
|
||||
|
||||
if (new_models_found or imported_models) and self.config_path:
|
||||
self.commit()
|
||||
|
||||
def autoimport(self)->set[Path]:
|
||||
def autoimport(self)->Dict[str, AddModelResult]:
|
||||
'''
|
||||
Scan the autoimport directory (if defined) and import new models, delete defunct models.
|
||||
'''
|
||||
@ -741,7 +756,6 @@ class ModelManager(object):
|
||||
prediction_type_helper = ask_user_for_prediction_type,
|
||||
)
|
||||
|
||||
installed = set()
|
||||
scanned_dirs = set()
|
||||
|
||||
config = self.app_config
|
||||
@ -755,13 +769,14 @@ class ModelManager(object):
|
||||
continue
|
||||
|
||||
self.logger.info(f'Scanning {autodir} for models to import')
|
||||
installed = dict()
|
||||
|
||||
autodir = self.app_config.root_path / autodir
|
||||
if not autodir.exists():
|
||||
continue
|
||||
|
||||
items_scanned = 0
|
||||
new_models_found = set()
|
||||
new_models_found = dict()
|
||||
|
||||
for root, dirs, files in os.walk(autodir):
|
||||
items_scanned += len(dirs) + len(files)
|
||||
@ -770,8 +785,8 @@ class ModelManager(object):
|
||||
if path in known_paths or path.parent in scanned_dirs:
|
||||
scanned_dirs.add(path)
|
||||
continue
|
||||
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
|
||||
new_models_found.update(installer.heuristic_install(path))
|
||||
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
|
||||
new_models_found.update(installer.heuristic_import(path))
|
||||
scanned_dirs.add(path)
|
||||
|
||||
for f in files:
|
||||
@ -779,7 +794,8 @@ class ModelManager(object):
|
||||
if path in known_paths or path.parent in scanned_dirs:
|
||||
continue
|
||||
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
|
||||
new_models_found.update(installer.heuristic_install(path))
|
||||
import_result = installer.heuristic_import(path)
|
||||
new_models_found.update(import_result)
|
||||
|
||||
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
|
||||
installed.update(new_models_found)
|
||||
@ -789,7 +805,7 @@ class ModelManager(object):
|
||||
def heuristic_import(self,
|
||||
items_to_import: Set[str],
|
||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||
)->Set[str]:
|
||||
)->Dict[str, AddModelResult]:
|
||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
@ -802,17 +818,20 @@ class ModelManager(object):
|
||||
generally impossible to do this programmatically, so the
|
||||
prediction_type_helper usually asks the user to choose.
|
||||
|
||||
The result is a set of successfully installed models. Each element
|
||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||
that model.
|
||||
'''
|
||||
# avoid circular import here
|
||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||
successfully_installed = set()
|
||||
successfully_installed = dict()
|
||||
|
||||
installer = ModelInstall(config = self.app_config,
|
||||
prediction_type_helper = prediction_type_helper,
|
||||
model_manager = self)
|
||||
for thing in items_to_import:
|
||||
try:
|
||||
installed = installer.heuristic_install(thing)
|
||||
installed = installer.heuristic_import(thing)
|
||||
successfully_installed.update(installed)
|
||||
except Exception as e:
|
||||
self.logger.warning(f'{thing} could not be imported: {str(e)}')
|
||||
|
@ -78,7 +78,6 @@ class ModelProbe(object):
|
||||
format_type = 'diffusers' if model_path.is_dir() else 'checkpoint'
|
||||
else:
|
||||
format_type = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
|
||||
|
||||
model_info = None
|
||||
try:
|
||||
model_type = cls.get_model_type_from_folder(model_path, model) \
|
||||
@ -105,7 +104,7 @@ class ModelProbe(object):
|
||||
) else 512,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
raise
|
||||
|
||||
return model_info
|
||||
|
||||
@ -127,6 +126,8 @@ class ModelProbe(object):
|
||||
return ModelType.Vae
|
||||
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
||||
return ModelType.Lora
|
||||
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
|
||||
return ModelType.Lora
|
||||
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
|
||||
return ModelType.ControlNet
|
||||
elif key in {"emb_params", "string_to_param"}:
|
||||
@ -137,7 +138,7 @@ class ModelProbe(object):
|
||||
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
||||
return ModelType.TextualInversion
|
||||
|
||||
raise ValueError("Unable to determine model type")
|
||||
raise ValueError(f"Unable to determine model type for {model_path}")
|
||||
|
||||
@classmethod
|
||||
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
|
||||
@ -167,7 +168,7 @@ class ModelProbe(object):
|
||||
return type
|
||||
|
||||
# give up
|
||||
raise ValueError("Unable to determine model type")
|
||||
raise ValueError("Unable to determine model type for {folder_path}")
|
||||
|
||||
@classmethod
|
||||
def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
|
||||
|
@ -2,7 +2,7 @@ import inspect
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel
|
||||
from typing import Literal, get_origin
|
||||
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
|
||||
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException
|
||||
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
||||
from .vae import VaeModel
|
||||
from .lora import LoRAModel
|
||||
|
@ -15,6 +15,9 @@ from contextlib import suppress
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
|
||||
|
||||
class ModelNotFoundException(Exception):
|
||||
pass
|
||||
|
||||
class BaseModelType(str, Enum):
|
||||
StableDiffusion1 = "sd-1"
|
||||
StableDiffusion2 = "sd-2"
|
||||
|
@ -8,6 +8,7 @@ from .base import (
|
||||
ModelType,
|
||||
SubModelType,
|
||||
classproperty,
|
||||
ModelNotFoundException,
|
||||
)
|
||||
# TODO: naming
|
||||
from ..lora import TextualInversionModel as TextualInversionModelRaw
|
||||
@ -37,8 +38,15 @@ class TextualInversionModel(ModelBase):
|
||||
if child_type is not None:
|
||||
raise Exception("There is no child models in textual inversion")
|
||||
|
||||
checkpoint_path = self.model_path
|
||||
if os.path.isdir(checkpoint_path):
|
||||
checkpoint_path = os.path.join(checkpoint_path, "learned_embeds.bin")
|
||||
|
||||
if not os.path.exists(checkpoint_path):
|
||||
raise ModelNotFoundException()
|
||||
|
||||
model = TextualInversionModelRaw.from_checkpoint(
|
||||
file_path=self.model_path,
|
||||
file_path=checkpoint_path,
|
||||
dtype=torch_dtype,
|
||||
)
|
||||
|
||||
|
@ -678,9 +678,8 @@ def select_and_download_models(opt: Namespace):
|
||||
|
||||
# this is where the TUI is called
|
||||
else:
|
||||
# needed because the torch library is loaded, even though we don't use it
|
||||
# currently commented out because it has started generating errors (?)
|
||||
# torch.multiprocessing.set_start_method("spawn")
|
||||
# needed to support the probe() method running under a subprocess
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
|
||||
# the third argument is needed in the Windows 11 environment in
|
||||
# order to launch and resize a console window running this program
|
||||
|
@ -36,6 +36,12 @@ module.exports = {
|
||||
],
|
||||
'prettier/prettier': ['error', { endOfLine: 'auto' }],
|
||||
'@typescript-eslint/ban-ts-comment': 'warn',
|
||||
'@typescript-eslint/no-empty-interface': [
|
||||
'error',
|
||||
{
|
||||
allowSingleExtends: true,
|
||||
},
|
||||
],
|
||||
},
|
||||
settings: {
|
||||
react: {
|
||||
|
2
invokeai/frontend/web/dist/index.html
vendored
2
invokeai/frontend/web/dist/index.html
vendored
@ -12,7 +12,7 @@
|
||||
margin: 0;
|
||||
}
|
||||
</style>
|
||||
<script type="module" crossorigin src="./assets/index-8a3e9251.js"></script>
|
||||
<script type="module" crossorigin src="./assets/index-c0367e37.js"></script>
|
||||
</head>
|
||||
|
||||
<body dir="ltr">
|
||||
|
17
invokeai/frontend/web/dist/locales/en.json
vendored
17
invokeai/frontend/web/dist/locales/en.json
vendored
@ -24,16 +24,13 @@
|
||||
},
|
||||
"common": {
|
||||
"hotkeysLabel": "Hotkeys",
|
||||
"themeLabel": "Theme",
|
||||
"darkMode": "Dark Mode",
|
||||
"lightMode": "Light Mode",
|
||||
"languagePickerLabel": "Language",
|
||||
"reportBugLabel": "Report Bug",
|
||||
"githubLabel": "Github",
|
||||
"discordLabel": "Discord",
|
||||
"settingsLabel": "Settings",
|
||||
"darkTheme": "Dark",
|
||||
"lightTheme": "Light",
|
||||
"greenTheme": "Green",
|
||||
"oceanTheme": "Ocean",
|
||||
"langArabic": "العربية",
|
||||
"langEnglish": "English",
|
||||
"langDutch": "Nederlands",
|
||||
@ -55,6 +52,7 @@
|
||||
"unifiedCanvas": "Unified Canvas",
|
||||
"linear": "Linear",
|
||||
"nodes": "Node Editor",
|
||||
"modelmanager": "Model Manager",
|
||||
"postprocessing": "Post Processing",
|
||||
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
|
||||
"postProcessing": "Post Processing",
|
||||
@ -336,6 +334,7 @@
|
||||
"modelManager": {
|
||||
"modelManager": "Model Manager",
|
||||
"model": "Model",
|
||||
"vae": "VAE",
|
||||
"allModels": "All Models",
|
||||
"checkpointModels": "Checkpoints",
|
||||
"diffusersModels": "Diffusers",
|
||||
@ -351,6 +350,7 @@
|
||||
"scanForModels": "Scan For Models",
|
||||
"addManually": "Add Manually",
|
||||
"manual": "Manual",
|
||||
"baseModel": "Base Model",
|
||||
"name": "Name",
|
||||
"nameValidationMsg": "Enter a name for your model",
|
||||
"description": "Description",
|
||||
@ -363,6 +363,7 @@
|
||||
"repoIDValidationMsg": "Online repository of your model",
|
||||
"vaeLocation": "VAE Location",
|
||||
"vaeLocationValidationMsg": "Path to where your VAE is located.",
|
||||
"variant": "Variant",
|
||||
"vaeRepoID": "VAE Repo ID",
|
||||
"vaeRepoIDValidationMsg": "Online repository of your VAE",
|
||||
"width": "Width",
|
||||
@ -524,7 +525,8 @@
|
||||
"initialImage": "Initial Image",
|
||||
"showOptionsPanel": "Show Options Panel",
|
||||
"hidePreview": "Hide Preview",
|
||||
"showPreview": "Show Preview"
|
||||
"showPreview": "Show Preview",
|
||||
"controlNetControlMode": "Control Mode"
|
||||
},
|
||||
"settings": {
|
||||
"models": "Models",
|
||||
@ -547,7 +549,8 @@
|
||||
"general": "General",
|
||||
"generation": "Generation",
|
||||
"ui": "User Interface",
|
||||
"availableSchedulers": "Available Schedulers"
|
||||
"favoriteSchedulers": "Favorite Schedulers",
|
||||
"favoriteSchedulersPlaceholder": "No schedulers favorited"
|
||||
},
|
||||
"toast": {
|
||||
"serverError": "Server Error",
|
||||
|
@ -67,6 +67,7 @@
|
||||
"@fontsource-variable/inter": "^5.0.3",
|
||||
"@fontsource/inter": "^5.0.3",
|
||||
"@mantine/core": "^6.0.14",
|
||||
"@mantine/form": "^6.0.15",
|
||||
"@mantine/hooks": "^6.0.14",
|
||||
"@reduxjs/toolkit": "^1.9.5",
|
||||
"@roarr/browser-log-writer": "^1.1.5",
|
||||
@ -82,7 +83,7 @@
|
||||
"konva": "^9.2.0",
|
||||
"lodash-es": "^4.17.21",
|
||||
"nanostores": "^0.9.2",
|
||||
"openapi-fetch": "^0.4.0",
|
||||
"openapi-fetch": "0.4.0",
|
||||
"overlayscrollbars": "^2.2.0",
|
||||
"overlayscrollbars-react": "^0.5.0",
|
||||
"patch-package": "^7.0.0",
|
||||
|
@ -52,6 +52,8 @@
|
||||
"unifiedCanvas": "Unified Canvas",
|
||||
"linear": "Linear",
|
||||
"nodes": "Node Editor",
|
||||
"batch": "Batch Manager",
|
||||
"modelmanager": "Model Manager",
|
||||
"postprocessing": "Post Processing",
|
||||
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
|
||||
"postProcessing": "Post Processing",
|
||||
@ -333,6 +335,7 @@
|
||||
"modelManager": {
|
||||
"modelManager": "Model Manager",
|
||||
"model": "Model",
|
||||
"vae": "VAE",
|
||||
"allModels": "All Models",
|
||||
"checkpointModels": "Checkpoints",
|
||||
"diffusersModels": "Diffusers",
|
||||
@ -348,6 +351,7 @@
|
||||
"scanForModels": "Scan For Models",
|
||||
"addManually": "Add Manually",
|
||||
"manual": "Manual",
|
||||
"baseModel": "Base Model",
|
||||
"name": "Name",
|
||||
"nameValidationMsg": "Enter a name for your model",
|
||||
"description": "Description",
|
||||
@ -360,6 +364,7 @@
|
||||
"repoIDValidationMsg": "Online repository of your model",
|
||||
"vaeLocation": "VAE Location",
|
||||
"vaeLocationValidationMsg": "Path to where your VAE is located.",
|
||||
"variant": "Variant",
|
||||
"vaeRepoID": "VAE Repo ID",
|
||||
"vaeRepoIDValidationMsg": "Online repository of your VAE",
|
||||
"width": "Width",
|
||||
|
@ -1,67 +1,40 @@
|
||||
import { Box, Flex, Grid, Portal } from '@chakra-ui/react';
|
||||
import { Flex, Grid, Portal } from '@chakra-ui/react';
|
||||
import { useLogger } from 'app/logging/useLogger';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { PartialAppConfig } from 'app/types/invokeai';
|
||||
import ImageUploader from 'common/components/ImageUploader';
|
||||
import Loading from 'common/components/Loading/Loading';
|
||||
import GalleryDrawer from 'features/gallery/components/GalleryPanel';
|
||||
import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal';
|
||||
import Lightbox from 'features/lightbox/components/Lightbox';
|
||||
import SiteHeader from 'features/system/components/SiteHeader';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { useIsApplicationReady } from 'features/system/hooks/useIsApplicationReady';
|
||||
import { configChanged } from 'features/system/store/configSlice';
|
||||
import { languageSelector } from 'features/system/store/systemSelectors';
|
||||
import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton';
|
||||
import FloatingParametersPanelButtons from 'features/ui/components/FloatingParametersPanelButtons';
|
||||
import InvokeTabs from 'features/ui/components/InvokeTabs';
|
||||
import ParametersDrawer from 'features/ui/components/ParametersDrawer';
|
||||
import { AnimatePresence, motion } from 'framer-motion';
|
||||
import i18n from 'i18n';
|
||||
import { ReactNode, memo, useCallback, useEffect, useState } from 'react';
|
||||
import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
|
||||
import { ReactNode, memo, useEffect } from 'react';
|
||||
import DeleteBoardImagesModal from '../../features/gallery/components/Boards/DeleteBoardImagesModal';
|
||||
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
|
||||
import GlobalHotkeys from './GlobalHotkeys';
|
||||
import Toaster from './Toaster';
|
||||
import DeleteImageModal from 'features/gallery/components/DeleteImageModal';
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
|
||||
import { useListModelsQuery } from 'services/api/endpoints/models';
|
||||
import DeleteBoardImagesModal from '../../features/gallery/components/Boards/DeleteBoardImagesModal';
|
||||
|
||||
const DEFAULT_CONFIG = {};
|
||||
|
||||
interface Props {
|
||||
config?: PartialAppConfig;
|
||||
headerComponent?: ReactNode;
|
||||
setIsReady?: (isReady: boolean) => void;
|
||||
}
|
||||
|
||||
const App = ({
|
||||
config = DEFAULT_CONFIG,
|
||||
headerComponent,
|
||||
setIsReady,
|
||||
}: Props) => {
|
||||
const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
|
||||
const language = useAppSelector(languageSelector);
|
||||
|
||||
const log = useLogger();
|
||||
|
||||
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
|
||||
|
||||
const isApplicationReady = useIsApplicationReady();
|
||||
|
||||
const { data: pipelineModels } = useListModelsQuery({
|
||||
model_type: 'main',
|
||||
});
|
||||
const { data: controlnetModels } = useListModelsQuery({
|
||||
model_type: 'controlnet',
|
||||
});
|
||||
const { data: vaeModels } = useListModelsQuery({ model_type: 'vae' });
|
||||
const { data: loraModels } = useListModelsQuery({ model_type: 'lora' });
|
||||
const { data: embeddingModels } = useListModelsQuery({
|
||||
model_type: 'embedding',
|
||||
});
|
||||
|
||||
const [loadingOverridden, setLoadingOverridden] = useState(false);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
useEffect(() => {
|
||||
@ -73,27 +46,6 @@ const App = ({
|
||||
dispatch(configChanged(config));
|
||||
}, [dispatch, config, log]);
|
||||
|
||||
const handleOverrideClicked = useCallback(() => {
|
||||
setLoadingOverridden(true);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (isApplicationReady && setIsReady) {
|
||||
setIsReady(true);
|
||||
}
|
||||
|
||||
if (isApplicationReady) {
|
||||
// TODO: This is a jank fix for canvas not filling the screen on first load
|
||||
setTimeout(() => {
|
||||
dispatch(requestCanvasRescale());
|
||||
}, 200);
|
||||
}
|
||||
|
||||
return () => {
|
||||
setIsReady && setIsReady(false);
|
||||
};
|
||||
}, [dispatch, isApplicationReady, setIsReady]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Grid w="100vw" h="100vh" position="relative" overflow="hidden">
|
||||
@ -123,33 +75,6 @@ const App = ({
|
||||
|
||||
<GalleryDrawer />
|
||||
<ParametersDrawer />
|
||||
|
||||
<AnimatePresence>
|
||||
{!isApplicationReady && !loadingOverridden && (
|
||||
<motion.div
|
||||
key="loading"
|
||||
initial={{ opacity: 1 }}
|
||||
animate={{ opacity: 1 }}
|
||||
exit={{ opacity: 0 }}
|
||||
transition={{ duration: 0.3 }}
|
||||
style={{ zIndex: 3 }}
|
||||
>
|
||||
<Box position="absolute" top={0} left={0} w="100vw" h="100vh">
|
||||
<Loading />
|
||||
</Box>
|
||||
<Box
|
||||
onClick={handleOverrideClicked}
|
||||
position="absolute"
|
||||
top={0}
|
||||
right={0}
|
||||
cursor="pointer"
|
||||
w="2rem"
|
||||
h="2rem"
|
||||
/>
|
||||
</motion.div>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
|
||||
<Portal>
|
||||
<FloatingParametersPanelButtons />
|
||||
</Portal>
|
||||
|
@ -0,0 +1,122 @@
|
||||
import { Box, ChakraProps, Flex, Heading, Image } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { memo } from 'react';
|
||||
import { TypesafeDraggableData } from './typesafeDnd';
|
||||
|
||||
type OverlayDragImageProps = {
|
||||
dragData: TypesafeDraggableData | null;
|
||||
};
|
||||
|
||||
const BOX_SIZE = 28;
|
||||
|
||||
const STYLES: ChakraProps['sx'] = {
|
||||
w: BOX_SIZE,
|
||||
h: BOX_SIZE,
|
||||
maxW: BOX_SIZE,
|
||||
maxH: BOX_SIZE,
|
||||
shadow: 'dark-lg',
|
||||
borderRadius: 'lg',
|
||||
borderWidth: 2,
|
||||
borderStyle: 'dashed',
|
||||
borderColor: 'base.100',
|
||||
opacity: 0.5,
|
||||
bg: 'base.800',
|
||||
color: 'base.50',
|
||||
_dark: {
|
||||
borderColor: 'base.200',
|
||||
bg: 'base.900',
|
||||
color: 'base.100',
|
||||
},
|
||||
};
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
(state) => {
|
||||
const gallerySelectionCount = state.gallery.selection.length;
|
||||
const batchSelectionCount = state.batch.selection.length;
|
||||
|
||||
return {
|
||||
gallerySelectionCount,
|
||||
batchSelectionCount,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const DragPreview = (props: OverlayDragImageProps) => {
|
||||
const { gallerySelectionCount, batchSelectionCount } =
|
||||
useAppSelector(selector);
|
||||
|
||||
if (!props.dragData) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (props.dragData.payloadType === 'IMAGE_DTO') {
|
||||
return (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'relative',
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
userSelect: 'none',
|
||||
cursor: 'none',
|
||||
}}
|
||||
>
|
||||
<Image
|
||||
sx={{
|
||||
...STYLES,
|
||||
}}
|
||||
src={props.dragData.payload.imageDTO.thumbnail_url}
|
||||
/>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
if (props.dragData.payloadType === 'BATCH_SELECTION') {
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
cursor: 'none',
|
||||
userSelect: 'none',
|
||||
position: 'relative',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
flexDir: 'column',
|
||||
...STYLES,
|
||||
}}
|
||||
>
|
||||
<Heading>{batchSelectionCount}</Heading>
|
||||
<Heading size="sm">Images</Heading>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
if (props.dragData.payloadType === 'GALLERY_SELECTION') {
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
cursor: 'none',
|
||||
userSelect: 'none',
|
||||
position: 'relative',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
flexDir: 'column',
|
||||
...STYLES,
|
||||
}}
|
||||
>
|
||||
<Heading>{gallerySelectionCount}</Heading>
|
||||
<Heading size="sm">Images</Heading>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
export default memo(DragPreview);
|
@ -1,8 +1,5 @@
|
||||
import {
|
||||
DndContext,
|
||||
DragEndEvent,
|
||||
DragOverlay,
|
||||
DragStartEvent,
|
||||
MouseSensor,
|
||||
TouchSensor,
|
||||
pointerWithin,
|
||||
@ -10,33 +7,45 @@ import {
|
||||
useSensors,
|
||||
} from '@dnd-kit/core';
|
||||
import { PropsWithChildren, memo, useCallback, useState } from 'react';
|
||||
import OverlayDragImage from './OverlayDragImage';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { isImageDTO } from 'services/api/guards';
|
||||
import DragPreview from './DragPreview';
|
||||
import { snapCenterToCursor } from '@dnd-kit/modifiers';
|
||||
import { AnimatePresence, motion } from 'framer-motion';
|
||||
import {
|
||||
DndContext,
|
||||
DragEndEvent,
|
||||
DragStartEvent,
|
||||
TypesafeDraggableData,
|
||||
} from './typesafeDnd';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { imageDropped } from 'app/store/middleware/listenerMiddleware/listeners/imageDropped';
|
||||
|
||||
type ImageDndContextProps = PropsWithChildren;
|
||||
|
||||
const ImageDndContext = (props: ImageDndContextProps) => {
|
||||
const [draggedImage, setDraggedImage] = useState<ImageDTO | null>(null);
|
||||
const [activeDragData, setActiveDragData] =
|
||||
useState<TypesafeDraggableData | null>(null);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleDragStart = useCallback((event: DragStartEvent) => {
|
||||
const dragData = event.active.data.current;
|
||||
if (dragData && 'image' in dragData && isImageDTO(dragData.image)) {
|
||||
setDraggedImage(dragData.image);
|
||||
const activeData = event.active.data.current;
|
||||
if (!activeData) {
|
||||
return;
|
||||
}
|
||||
setActiveDragData(activeData);
|
||||
}, []);
|
||||
|
||||
const handleDragEnd = useCallback(
|
||||
(event: DragEndEvent) => {
|
||||
const handleDrop = event.over?.data.current?.handleDrop;
|
||||
if (handleDrop && typeof handleDrop === 'function' && draggedImage) {
|
||||
handleDrop(draggedImage);
|
||||
const activeData = event.active.data.current;
|
||||
const overData = event.over?.data.current;
|
||||
if (!activeData || !overData) {
|
||||
return;
|
||||
}
|
||||
setDraggedImage(null);
|
||||
dispatch(imageDropped({ overData, activeData }));
|
||||
setActiveDragData(null);
|
||||
},
|
||||
[draggedImage]
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const mouseSensor = useSensor(MouseSensor, {
|
||||
@ -46,6 +55,7 @@ const ImageDndContext = (props: ImageDndContextProps) => {
|
||||
const touchSensor = useSensor(TouchSensor, {
|
||||
activationConstraint: { delay: 150, tolerance: 5 },
|
||||
});
|
||||
|
||||
// TODO: Use KeyboardSensor - needs composition of multiple collisionDetection algos
|
||||
// Alternatively, fix `rectIntersection` collection detection to work with the drag overlay
|
||||
// (currently the drag element collision rect is not correctly calculated)
|
||||
@ -63,7 +73,7 @@ const ImageDndContext = (props: ImageDndContextProps) => {
|
||||
{props.children}
|
||||
<DragOverlay dropAnimation={null} modifiers={[snapCenterToCursor]}>
|
||||
<AnimatePresence>
|
||||
{draggedImage && (
|
||||
{activeDragData && (
|
||||
<motion.div
|
||||
layout
|
||||
key="overlay-drag-image"
|
||||
@ -77,7 +87,7 @@ const ImageDndContext = (props: ImageDndContextProps) => {
|
||||
transition: { duration: 0.1 },
|
||||
}}
|
||||
>
|
||||
<OverlayDragImage image={draggedImage} />
|
||||
<DragPreview dragData={activeDragData} />
|
||||
</motion.div>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
|
@ -1,36 +0,0 @@
|
||||
import { Box, Image } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
|
||||
type OverlayDragImageProps = {
|
||||
image: ImageDTO;
|
||||
};
|
||||
|
||||
const OverlayDragImage = (props: OverlayDragImageProps) => {
|
||||
return (
|
||||
<Box
|
||||
style={{
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
userSelect: 'none',
|
||||
cursor: 'grabbing',
|
||||
opacity: 0.5,
|
||||
}}
|
||||
>
|
||||
<Image
|
||||
sx={{
|
||||
maxW: 36,
|
||||
maxH: 36,
|
||||
borderRadius: 'base',
|
||||
shadow: 'dark-lg',
|
||||
}}
|
||||
src={props.image.thumbnail_url}
|
||||
/>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(OverlayDragImage);
|
@ -0,0 +1,201 @@
|
||||
// type-safe dnd from https://github.com/clauderic/dnd-kit/issues/935
|
||||
import {
|
||||
Active,
|
||||
Collision,
|
||||
DndContextProps,
|
||||
DndContext as OriginalDndContext,
|
||||
Over,
|
||||
Translate,
|
||||
UseDraggableArguments,
|
||||
UseDroppableArguments,
|
||||
useDraggable as useOriginalDraggable,
|
||||
useDroppable as useOriginalDroppable,
|
||||
} from '@dnd-kit/core';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
|
||||
type BaseDropData = {
|
||||
id: string;
|
||||
};
|
||||
|
||||
export type CurrentImageDropData = BaseDropData & {
|
||||
actionType: 'SET_CURRENT_IMAGE';
|
||||
};
|
||||
|
||||
export type InitialImageDropData = BaseDropData & {
|
||||
actionType: 'SET_INITIAL_IMAGE';
|
||||
};
|
||||
|
||||
export type ControlNetDropData = BaseDropData & {
|
||||
actionType: 'SET_CONTROLNET_IMAGE';
|
||||
context: {
|
||||
controlNetId: string;
|
||||
};
|
||||
};
|
||||
|
||||
export type CanvasInitialImageDropData = BaseDropData & {
|
||||
actionType: 'SET_CANVAS_INITIAL_IMAGE';
|
||||
};
|
||||
|
||||
export type NodesImageDropData = BaseDropData & {
|
||||
actionType: 'SET_NODES_IMAGE';
|
||||
context: {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
};
|
||||
};
|
||||
|
||||
export type NodesMultiImageDropData = BaseDropData & {
|
||||
actionType: 'SET_MULTI_NODES_IMAGE';
|
||||
context: { nodeId: string; fieldName: string };
|
||||
};
|
||||
|
||||
export type AddToBatchDropData = BaseDropData & {
|
||||
actionType: 'ADD_TO_BATCH';
|
||||
};
|
||||
|
||||
export type MoveBoardDropData = BaseDropData & {
|
||||
actionType: 'MOVE_BOARD';
|
||||
context: { boardId: string | null };
|
||||
};
|
||||
|
||||
export type TypesafeDroppableData =
|
||||
| CurrentImageDropData
|
||||
| InitialImageDropData
|
||||
| ControlNetDropData
|
||||
| CanvasInitialImageDropData
|
||||
| NodesImageDropData
|
||||
| AddToBatchDropData
|
||||
| NodesMultiImageDropData
|
||||
| MoveBoardDropData;
|
||||
|
||||
type BaseDragData = {
|
||||
id: string;
|
||||
};
|
||||
|
||||
export type ImageDraggableData = BaseDragData & {
|
||||
payloadType: 'IMAGE_DTO';
|
||||
payload: { imageDTO: ImageDTO };
|
||||
};
|
||||
|
||||
export type GallerySelectionDraggableData = BaseDragData & {
|
||||
payloadType: 'GALLERY_SELECTION';
|
||||
};
|
||||
|
||||
export type BatchSelectionDraggableData = BaseDragData & {
|
||||
payloadType: 'BATCH_SELECTION';
|
||||
};
|
||||
|
||||
export type TypesafeDraggableData =
|
||||
| ImageDraggableData
|
||||
| GallerySelectionDraggableData
|
||||
| BatchSelectionDraggableData;
|
||||
|
||||
interface UseDroppableTypesafeArguments
|
||||
extends Omit<UseDroppableArguments, 'data'> {
|
||||
data?: TypesafeDroppableData;
|
||||
}
|
||||
|
||||
type UseDroppableTypesafeReturnValue = Omit<
|
||||
ReturnType<typeof useOriginalDroppable>,
|
||||
'active' | 'over'
|
||||
> & {
|
||||
active: TypesafeActive | null;
|
||||
over: TypesafeOver | null;
|
||||
};
|
||||
|
||||
export function useDroppable(props: UseDroppableTypesafeArguments) {
|
||||
return useOriginalDroppable(props) as UseDroppableTypesafeReturnValue;
|
||||
}
|
||||
|
||||
interface UseDraggableTypesafeArguments
|
||||
extends Omit<UseDraggableArguments, 'data'> {
|
||||
data?: TypesafeDraggableData;
|
||||
}
|
||||
|
||||
type UseDraggableTypesafeReturnValue = Omit<
|
||||
ReturnType<typeof useOriginalDraggable>,
|
||||
'active' | 'over'
|
||||
> & {
|
||||
active: TypesafeActive | null;
|
||||
over: TypesafeOver | null;
|
||||
};
|
||||
|
||||
export function useDraggable(props: UseDraggableTypesafeArguments) {
|
||||
return useOriginalDraggable(props) as UseDraggableTypesafeReturnValue;
|
||||
}
|
||||
|
||||
interface TypesafeActive extends Omit<Active, 'data'> {
|
||||
data: React.MutableRefObject<TypesafeDraggableData | undefined>;
|
||||
}
|
||||
|
||||
interface TypesafeOver extends Omit<Over, 'data'> {
|
||||
data: React.MutableRefObject<TypesafeDroppableData | undefined>;
|
||||
}
|
||||
|
||||
export const isValidDrop = (
|
||||
overData: TypesafeDroppableData | undefined,
|
||||
active: TypesafeActive | null
|
||||
) => {
|
||||
if (!overData || !active?.data.current) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const { actionType } = overData;
|
||||
const { payloadType } = active.data.current;
|
||||
|
||||
if (overData.id === active.data.current.id) {
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (actionType) {
|
||||
case 'SET_CURRENT_IMAGE':
|
||||
return payloadType === 'IMAGE_DTO';
|
||||
case 'SET_INITIAL_IMAGE':
|
||||
return payloadType === 'IMAGE_DTO';
|
||||
case 'SET_CONTROLNET_IMAGE':
|
||||
return payloadType === 'IMAGE_DTO';
|
||||
case 'SET_CANVAS_INITIAL_IMAGE':
|
||||
return payloadType === 'IMAGE_DTO';
|
||||
case 'SET_NODES_IMAGE':
|
||||
return payloadType === 'IMAGE_DTO';
|
||||
case 'SET_MULTI_NODES_IMAGE':
|
||||
return payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION';
|
||||
case 'ADD_TO_BATCH':
|
||||
return payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION';
|
||||
case 'MOVE_BOARD':
|
||||
return (
|
||||
payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION' || 'BATCH_SELECTION'
|
||||
);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
interface DragEvent {
|
||||
activatorEvent: Event;
|
||||
active: TypesafeActive;
|
||||
collisions: Collision[] | null;
|
||||
delta: Translate;
|
||||
over: TypesafeOver | null;
|
||||
}
|
||||
|
||||
export interface DragStartEvent extends Pick<DragEvent, 'active'> {}
|
||||
export interface DragMoveEvent extends DragEvent {}
|
||||
export interface DragOverEvent extends DragMoveEvent {}
|
||||
export interface DragEndEvent extends DragEvent {}
|
||||
export interface DragCancelEvent extends DragEndEvent {}
|
||||
|
||||
export interface DndContextTypesafeProps
|
||||
extends Omit<
|
||||
DndContextProps,
|
||||
'onDragStart' | 'onDragMove' | 'onDragOver' | 'onDragEnd' | 'onDragCancel'
|
||||
> {
|
||||
onDragStart?(event: DragStartEvent): void;
|
||||
onDragMove?(event: DragMoveEvent): void;
|
||||
onDragOver?(event: DragOverEvent): void;
|
||||
onDragEnd?(event: DragEndEvent): void;
|
||||
onDragCancel?(event: DragCancelEvent): void;
|
||||
}
|
||||
export function DndContext(props: DndContextTypesafeProps) {
|
||||
return <OriginalDndContext {...props} />;
|
||||
}
|
@ -7,7 +7,6 @@ import React, {
|
||||
} from 'react';
|
||||
import { Provider } from 'react-redux';
|
||||
import { store } from 'app/store/store';
|
||||
// import { OpenAPI } from 'services/api/types';
|
||||
|
||||
import Loading from '../../common/components/Loading/Loading';
|
||||
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
|
||||
@ -17,11 +16,6 @@ import '../../i18n';
|
||||
import { socketMiddleware } from 'services/events/middleware';
|
||||
import { Middleware } from '@reduxjs/toolkit';
|
||||
import ImageDndContext from './ImageDnd/ImageDndContext';
|
||||
import {
|
||||
DeleteImageContext,
|
||||
DeleteImageContextProvider,
|
||||
} from 'app/contexts/DeleteImageContext';
|
||||
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
|
||||
import { AddImageToBoardContextProvider } from '../contexts/AddImageToBoardContext';
|
||||
import { $authToken, $baseUrl } from 'services/api/client';
|
||||
import { DeleteBoardImagesContextProvider } from '../contexts/DeleteBoardImagesContext';
|
||||
@ -34,7 +28,6 @@ interface Props extends PropsWithChildren {
|
||||
token?: string;
|
||||
config?: PartialAppConfig;
|
||||
headerComponent?: ReactNode;
|
||||
setIsReady?: (isReady: boolean) => void;
|
||||
middleware?: Middleware[];
|
||||
}
|
||||
|
||||
@ -43,7 +36,6 @@ const InvokeAIUI = ({
|
||||
token,
|
||||
config,
|
||||
headerComponent,
|
||||
setIsReady,
|
||||
middleware,
|
||||
}: Props) => {
|
||||
useEffect(() => {
|
||||
@ -85,17 +77,11 @@ const InvokeAIUI = ({
|
||||
<React.Suspense fallback={<Loading />}>
|
||||
<ThemeLocaleProvider>
|
||||
<ImageDndContext>
|
||||
<DeleteImageContextProvider>
|
||||
<AddImageToBoardContextProvider>
|
||||
<DeleteBoardImagesContextProvider>
|
||||
<App
|
||||
config={config}
|
||||
headerComponent={headerComponent}
|
||||
setIsReady={setIsReady}
|
||||
/>
|
||||
</DeleteBoardImagesContextProvider>
|
||||
</AddImageToBoardContextProvider>
|
||||
</DeleteImageContextProvider>
|
||||
<AddImageToBoardContextProvider>
|
||||
<DeleteBoardImagesContextProvider>
|
||||
<App config={config} headerComponent={headerComponent} />
|
||||
</DeleteBoardImagesContextProvider>
|
||||
</AddImageToBoardContextProvider>
|
||||
</ImageDndContext>
|
||||
</ThemeLocaleProvider>
|
||||
</React.Suspense>
|
||||
|
@ -5,15 +5,15 @@ import { useDeleteBoardMutation } from '../../services/api/endpoints/boards';
|
||||
import { defaultSelectorOptions } from '../store/util/defaultMemoizeOptions';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { some } from 'lodash-es';
|
||||
import { canvasSelector } from '../../features/canvas/store/canvasSelectors';
|
||||
import { controlNetSelector } from '../../features/controlNet/store/controlNetSlice';
|
||||
import { selectImagesById } from '../../features/gallery/store/imagesSlice';
|
||||
import { nodesSelector } from '../../features/nodes/store/nodesSlice';
|
||||
import { generationSelector } from '../../features/parameters/store/generationSelectors';
|
||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||
import { controlNetSelector } from 'features/controlNet/store/controlNetSlice';
|
||||
import { selectImagesById } from 'features/gallery/store/gallerySlice';
|
||||
import { nodesSelector } from 'features/nodes/store/nodesSlice';
|
||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||
import { RootState } from '../store/store';
|
||||
import { useAppDispatch, useAppSelector } from '../store/storeHooks';
|
||||
import { ImageUsage } from './DeleteImageContext';
|
||||
import { requestedBoardImagesDeletion } from '../../features/gallery/store/actions';
|
||||
import { requestedBoardImagesDeletion } from 'features/gallery/store/actions';
|
||||
|
||||
export const selectBoardImagesUsage = createSelector(
|
||||
[
|
||||
|
@ -1,201 +0,0 @@
|
||||
import { useDisclosure } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { requestedImageDeletion } from 'features/gallery/store/actions';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import {
|
||||
PropsWithChildren,
|
||||
createContext,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useState,
|
||||
} from 'react';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||
import { controlNetSelector } from 'features/controlNet/store/controlNetSlice';
|
||||
import { nodesSelector } from 'features/nodes/store/nodesSlice';
|
||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||
import { some } from 'lodash-es';
|
||||
|
||||
export type ImageUsage = {
|
||||
isInitialImage: boolean;
|
||||
isCanvasImage: boolean;
|
||||
isNodesImage: boolean;
|
||||
isControlNetImage: boolean;
|
||||
};
|
||||
|
||||
export const selectImageUsage = createSelector(
|
||||
[
|
||||
generationSelector,
|
||||
canvasSelector,
|
||||
nodesSelector,
|
||||
controlNetSelector,
|
||||
(state: RootState, image_name?: string) => image_name,
|
||||
],
|
||||
(generation, canvas, nodes, controlNet, image_name) => {
|
||||
const isInitialImage = generation.initialImage?.imageName === image_name;
|
||||
|
||||
const isCanvasImage = canvas.layerState.objects.some(
|
||||
(obj) => obj.kind === 'image' && obj.imageName === image_name
|
||||
);
|
||||
|
||||
const isNodesImage = nodes.nodes.some((node) => {
|
||||
return some(
|
||||
node.data.inputs,
|
||||
(input) => input.type === 'image' && input.value === image_name
|
||||
);
|
||||
});
|
||||
|
||||
const isControlNetImage = some(
|
||||
controlNet.controlNets,
|
||||
(c) =>
|
||||
c.controlImage === image_name || c.processedControlImage === image_name
|
||||
);
|
||||
|
||||
const imageUsage: ImageUsage = {
|
||||
isInitialImage,
|
||||
isCanvasImage,
|
||||
isNodesImage,
|
||||
isControlNetImage,
|
||||
};
|
||||
|
||||
return imageUsage;
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
type DeleteImageContextValue = {
|
||||
/**
|
||||
* Whether the delete image dialog is open.
|
||||
*/
|
||||
isOpen: boolean;
|
||||
/**
|
||||
* Closes the delete image dialog.
|
||||
*/
|
||||
onClose: () => void;
|
||||
/**
|
||||
* Opens the delete image dialog and handles all deletion-related checks.
|
||||
*/
|
||||
onDelete: (image?: ImageDTO) => void;
|
||||
/**
|
||||
* The image pending deletion
|
||||
*/
|
||||
image?: ImageDTO;
|
||||
/**
|
||||
* The features in which this image is used
|
||||
*/
|
||||
imageUsage?: ImageUsage;
|
||||
/**
|
||||
* Immediately deletes an image.
|
||||
*
|
||||
* You probably don't want to use this - use `onDelete` instead.
|
||||
*/
|
||||
onImmediatelyDelete: () => void;
|
||||
};
|
||||
|
||||
export const DeleteImageContext = createContext<DeleteImageContextValue>({
|
||||
isOpen: false,
|
||||
onClose: () => undefined,
|
||||
onImmediatelyDelete: () => undefined,
|
||||
onDelete: () => undefined,
|
||||
});
|
||||
|
||||
const selector = createSelector(
|
||||
[systemSelector],
|
||||
(system) => {
|
||||
const { isProcessing, isConnected, shouldConfirmOnDelete } = system;
|
||||
|
||||
return {
|
||||
canDeleteImage: isConnected && !isProcessing,
|
||||
shouldConfirmOnDelete,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
type Props = PropsWithChildren;
|
||||
|
||||
export const DeleteImageContextProvider = (props: Props) => {
|
||||
const { canDeleteImage, shouldConfirmOnDelete } = useAppSelector(selector);
|
||||
const [imageToDelete, setImageToDelete] = useState<ImageDTO>();
|
||||
const dispatch = useAppDispatch();
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
|
||||
// Check where the image to be deleted is used (eg init image, controlnet, etc.)
|
||||
const imageUsage = useAppSelector((state) =>
|
||||
selectImageUsage(state, imageToDelete?.image_name)
|
||||
);
|
||||
|
||||
// Clean up after deleting or dismissing the modal
|
||||
const closeAndClearImageToDelete = useCallback(() => {
|
||||
setImageToDelete(undefined);
|
||||
onClose();
|
||||
}, [onClose]);
|
||||
|
||||
// Dispatch the actual deletion action, to be handled by listener middleware
|
||||
const handleActualDeletion = useCallback(
|
||||
(image: ImageDTO) => {
|
||||
dispatch(requestedImageDeletion({ image, imageUsage }));
|
||||
closeAndClearImageToDelete();
|
||||
},
|
||||
[closeAndClearImageToDelete, dispatch, imageUsage]
|
||||
);
|
||||
|
||||
// This is intended to be called by the delete button in the dialog
|
||||
const onImmediatelyDelete = useCallback(() => {
|
||||
if (canDeleteImage && imageToDelete) {
|
||||
handleActualDeletion(imageToDelete);
|
||||
}
|
||||
closeAndClearImageToDelete();
|
||||
}, [
|
||||
canDeleteImage,
|
||||
imageToDelete,
|
||||
closeAndClearImageToDelete,
|
||||
handleActualDeletion,
|
||||
]);
|
||||
|
||||
const handleGatedDeletion = useCallback(
|
||||
(image: ImageDTO) => {
|
||||
if (shouldConfirmOnDelete || some(imageUsage)) {
|
||||
// If we should confirm on delete, or if the image is in use, open the dialog
|
||||
onOpen();
|
||||
} else {
|
||||
handleActualDeletion(image);
|
||||
}
|
||||
},
|
||||
[imageUsage, shouldConfirmOnDelete, onOpen, handleActualDeletion]
|
||||
);
|
||||
|
||||
// Consumers of the context call this to delete an image
|
||||
const onDelete = useCallback((image?: ImageDTO) => {
|
||||
if (!image) {
|
||||
return;
|
||||
}
|
||||
// Set the image to delete, then let the effect call the actual deletion
|
||||
setImageToDelete(image);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
// We need to use an effect here to trigger the image usage selector, else we get a stale value
|
||||
if (imageToDelete) {
|
||||
handleGatedDeletion(imageToDelete);
|
||||
}
|
||||
}, [handleGatedDeletion, imageToDelete]);
|
||||
|
||||
return (
|
||||
<DeleteImageContext.Provider
|
||||
value={{
|
||||
isOpen,
|
||||
image: imageToDelete,
|
||||
onClose: closeAndClearImageToDelete,
|
||||
onDelete,
|
||||
onImmediatelyDelete,
|
||||
imageUsage,
|
||||
}}
|
||||
>
|
||||
{props.children}
|
||||
</DeleteImageContext.Provider>
|
||||
);
|
||||
};
|
@ -20,10 +20,8 @@ const serializationDenylist: {
|
||||
nodes: nodesPersistDenylist,
|
||||
postprocessing: postprocessingPersistDenylist,
|
||||
system: systemPersistDenylist,
|
||||
// config: configPersistDenyList,
|
||||
ui: uiPersistDenylist,
|
||||
controlNet: controlNetDenylist,
|
||||
// hotkeys: hotkeysPersistDenylist,
|
||||
};
|
||||
|
||||
export const serialize: SerializeFunction = (data, key) => {
|
||||
|
@ -1,7 +1,6 @@
|
||||
import { initialCanvasState } from 'features/canvas/store/canvasSlice';
|
||||
import { initialControlNetState } from 'features/controlNet/store/controlNetSlice';
|
||||
import { initialGalleryState } from 'features/gallery/store/gallerySlice';
|
||||
import { initialImagesState } from 'features/gallery/store/imagesSlice';
|
||||
import { initialLightboxState } from 'features/lightbox/store/lightboxSlice';
|
||||
import { initialNodesState } from 'features/nodes/store/nodesSlice';
|
||||
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
||||
@ -26,7 +25,6 @@ const initialStates: {
|
||||
config: initialConfigState,
|
||||
ui: initialUIState,
|
||||
hotkeys: initialHotkeysState,
|
||||
images: initialImagesState,
|
||||
controlNet: initialControlNetState,
|
||||
};
|
||||
|
||||
|
@ -72,7 +72,6 @@ import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingA
|
||||
import { addImageCategoriesChangedListener } from './listeners/imageCategoriesChanged';
|
||||
import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed';
|
||||
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
|
||||
import { addUpdateImageUrlsOnConnectListener } from './listeners/updateImageUrlsOnConnect';
|
||||
import {
|
||||
addImageAddedToBoardFulfilledListener,
|
||||
addImageAddedToBoardRejectedListener,
|
||||
@ -84,6 +83,9 @@ import {
|
||||
} from './listeners/imageRemovedFromBoard';
|
||||
import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema';
|
||||
import { addRequestedBoardImageDeletionListener } from './listeners/boardImagesDeleted';
|
||||
import { addSelectionAddedToBatchListener } from './listeners/selectionAddedToBatch';
|
||||
import { addImageDroppedListener } from './listeners/imageDropped';
|
||||
import { addImageToDeleteSelectedListener } from './listeners/imageToDeleteSelected';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
@ -126,6 +128,7 @@ addImageDeletedPendingListener();
|
||||
addImageDeletedFulfilledListener();
|
||||
addImageDeletedRejectedListener();
|
||||
addRequestedBoardImageDeletionListener();
|
||||
addImageToDeleteSelectedListener();
|
||||
|
||||
// Image metadata
|
||||
addImageMetadataReceivedFulfilledListener();
|
||||
@ -211,3 +214,9 @@ addBoardIdSelectedListener();
|
||||
|
||||
// Node schemas
|
||||
addReceivedOpenAPISchemaListener();
|
||||
|
||||
// Batches
|
||||
addSelectionAddedToBatchListener();
|
||||
|
||||
// DND
|
||||
addImageDroppedListener();
|
||||
|
@ -1,12 +1,14 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { boardIdSelected } from 'features/gallery/store/boardSlice';
|
||||
import { selectImagesAll } from 'features/gallery/store/imagesSlice';
|
||||
import {
|
||||
imageSelected,
|
||||
selectImagesAll,
|
||||
boardIdSelected,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import {
|
||||
IMAGES_PER_PAGE,
|
||||
receivedPageOfImages,
|
||||
} from 'services/api/thunks/image';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'boards' });
|
||||
@ -28,7 +30,7 @@ export const addBoardIdSelectedListener = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
const { categories } = state.images;
|
||||
const { categories } = state.gallery;
|
||||
|
||||
const filteredImages = allImages.filter((i) => {
|
||||
const isInCategory = categories.includes(i.image_category);
|
||||
@ -47,7 +49,7 @@ export const addBoardIdSelectedListener = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(imageSelected(board.cover_image_name));
|
||||
dispatch(imageSelected(board.cover_image_name ?? null));
|
||||
|
||||
// if we haven't loaded one full page of images from this board, load more
|
||||
if (
|
||||
@ -77,7 +79,7 @@ export const addBoardIdSelected_changeSelectedImage_listener = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
const { categories } = state.images;
|
||||
const { categories } = state.gallery;
|
||||
|
||||
const filteredImages = selectImagesAll(state).filter((i) => {
|
||||
const isInCategory = categories.includes(i.image_category);
|
||||
|
@ -1,11 +1,11 @@
|
||||
import { requestedBoardImagesDeletion } from 'features/gallery/store/actions';
|
||||
import { startAppListening } from '..';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import {
|
||||
imageSelected,
|
||||
imagesRemoved,
|
||||
selectImagesAll,
|
||||
selectImagesById,
|
||||
} from 'features/gallery/store/imagesSlice';
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { resetCanvas } from 'features/canvas/store/canvasSlice';
|
||||
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
|
||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||
@ -22,12 +22,15 @@ export const addRequestedBoardImageDeletionListener = () => {
|
||||
const { board_id } = board;
|
||||
|
||||
const state = getState();
|
||||
const selectedImage = state.gallery.selectedImage
|
||||
? selectImagesById(state, state.gallery.selectedImage)
|
||||
const selectedImageName =
|
||||
state.gallery.selection[state.gallery.selection.length - 1];
|
||||
|
||||
const selectedImage = selectedImageName
|
||||
? selectImagesById(state, selectedImageName)
|
||||
: undefined;
|
||||
|
||||
if (selectedImage && selectedImage.board_id === board_id) {
|
||||
dispatch(imageSelected());
|
||||
dispatch(imageSelected(null));
|
||||
}
|
||||
|
||||
// We need to reset the features where the board images are in use - none of these work if their image(s) don't exist
|
||||
|
@ -4,7 +4,7 @@ import { log } from 'app/logging/useLogger';
|
||||
import { imageUploaded } from 'services/api/thunks/image';
|
||||
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { imageUpserted } from 'features/gallery/store/imagesSlice';
|
||||
import { imageUpserted } from 'features/gallery/store/gallerySlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' });
|
||||
|
||||
|
@ -3,8 +3,8 @@ import { startAppListening } from '..';
|
||||
import { receivedPageOfImages } from 'services/api/thunks/image';
|
||||
import {
|
||||
imageCategoriesChanged,
|
||||
selectFilteredImagesAsArray,
|
||||
} from 'features/gallery/store/imagesSlice';
|
||||
selectFilteredImages,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'gallery' });
|
||||
|
||||
@ -13,7 +13,7 @@ export const addImageCategoriesChangedListener = () => {
|
||||
actionCreator: imageCategoriesChanged,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const state = getState();
|
||||
const filteredImagesCount = selectFilteredImagesAsArray(state).length;
|
||||
const filteredImagesCount = selectFilteredImages(state).length;
|
||||
|
||||
if (!filteredImagesCount) {
|
||||
dispatch(
|
||||
|
@ -1,18 +1,21 @@
|
||||
import { requestedImageDeletion } from 'features/gallery/store/actions';
|
||||
import { startAppListening } from '..';
|
||||
import { imageDeleted } from 'services/api/thunks/image';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { clamp } from 'lodash-es';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import {
|
||||
imageRemoved,
|
||||
selectImagesIds,
|
||||
} from 'features/gallery/store/imagesSlice';
|
||||
import { resetCanvas } from 'features/canvas/store/canvasSlice';
|
||||
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
|
||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||
import {
|
||||
imageRemoved,
|
||||
imageSelected,
|
||||
selectFilteredImages,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import {
|
||||
imageDeletionConfirmed,
|
||||
isModalOpenChanged,
|
||||
} from 'features/imageDeletion/store/imageDeletionSlice';
|
||||
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
|
||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||
import { clamp } from 'lodash-es';
|
||||
import { api } from 'services/api';
|
||||
import { imageDeleted } from 'services/api/thunks/image';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'image' });
|
||||
|
||||
@ -21,17 +24,22 @@ const moduleLog = log.child({ namespace: 'image' });
|
||||
*/
|
||||
export const addRequestedImageDeletionListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: requestedImageDeletion,
|
||||
actionCreator: imageDeletionConfirmed,
|
||||
effect: async (action, { dispatch, getState, condition }) => {
|
||||
const { image, imageUsage } = action.payload;
|
||||
const { imageDTO, imageUsage } = action.payload;
|
||||
|
||||
const { image_name } = image;
|
||||
dispatch(isModalOpenChanged(false));
|
||||
|
||||
const { image_name } = imageDTO;
|
||||
|
||||
const state = getState();
|
||||
const selectedImage = state.gallery.selectedImage;
|
||||
const lastSelectedImage =
|
||||
state.gallery.selection[state.gallery.selection.length - 1];
|
||||
|
||||
if (selectedImage === image_name) {
|
||||
const ids = selectImagesIds(state);
|
||||
if (lastSelectedImage === image_name) {
|
||||
const filteredImages = selectFilteredImages(state);
|
||||
|
||||
const ids = filteredImages.map((i) => i.image_name);
|
||||
|
||||
const deletedImageIndex = ids.findIndex(
|
||||
(result) => result.toString() === image_name
|
||||
@ -50,7 +58,7 @@ export const addRequestedImageDeletionListener = () => {
|
||||
if (newSelectedImageId) {
|
||||
dispatch(imageSelected(newSelectedImageId as string));
|
||||
} else {
|
||||
dispatch(imageSelected());
|
||||
dispatch(imageSelected(null));
|
||||
}
|
||||
}
|
||||
|
||||
@ -88,7 +96,7 @@ export const addRequestedImageDeletionListener = () => {
|
||||
|
||||
if (wasImageDeleted) {
|
||||
dispatch(
|
||||
api.util.invalidateTags([{ type: 'Board', id: image.board_id }])
|
||||
api.util.invalidateTags([{ type: 'Board', id: imageDTO.board_id }])
|
||||
);
|
||||
}
|
||||
},
|
||||
|
@ -0,0 +1,188 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import {
|
||||
TypesafeDraggableData,
|
||||
TypesafeDroppableData,
|
||||
} from 'app/components/ImageDnd/typesafeDnd';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import {
|
||||
imageAddedToBatch,
|
||||
imagesAddedToBatch,
|
||||
} from 'features/batch/store/batchSlice';
|
||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import {
|
||||
fieldValueChanged,
|
||||
imageCollectionFieldValueChanged,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||
import { boardImagesApi } from 'services/api/endpoints/boardImages';
|
||||
import { startAppListening } from '../';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'dnd' });
|
||||
|
||||
export const imageDropped = createAction<{
|
||||
overData: TypesafeDroppableData;
|
||||
activeData: TypesafeDraggableData;
|
||||
}>('dnd/imageDropped');
|
||||
|
||||
export const addImageDroppedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: imageDropped,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const { activeData, overData } = action.payload;
|
||||
const { actionType } = overData;
|
||||
const state = getState();
|
||||
|
||||
// set current image
|
||||
if (
|
||||
actionType === 'SET_CURRENT_IMAGE' &&
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
dispatch(imageSelected(activeData.payload.imageDTO.image_name));
|
||||
}
|
||||
|
||||
// set initial image
|
||||
if (
|
||||
actionType === 'SET_INITIAL_IMAGE' &&
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
dispatch(initialImageChanged(activeData.payload.imageDTO));
|
||||
}
|
||||
|
||||
// add image to batch
|
||||
if (
|
||||
actionType === 'ADD_TO_BATCH' &&
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
dispatch(imageAddedToBatch(activeData.payload.imageDTO.image_name));
|
||||
}
|
||||
|
||||
// add multiple images to batch
|
||||
if (
|
||||
actionType === 'ADD_TO_BATCH' &&
|
||||
activeData.payloadType === 'GALLERY_SELECTION'
|
||||
) {
|
||||
dispatch(imagesAddedToBatch(state.gallery.selection));
|
||||
}
|
||||
|
||||
// set control image
|
||||
if (
|
||||
actionType === 'SET_CONTROLNET_IMAGE' &&
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
const { controlNetId } = overData.context;
|
||||
dispatch(
|
||||
controlNetImageChanged({
|
||||
controlImage: activeData.payload.imageDTO.image_name,
|
||||
controlNetId,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
// set canvas image
|
||||
if (
|
||||
actionType === 'SET_CANVAS_INITIAL_IMAGE' &&
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
dispatch(setInitialCanvasImage(activeData.payload.imageDTO));
|
||||
}
|
||||
|
||||
// set nodes image
|
||||
if (
|
||||
actionType === 'SET_NODES_IMAGE' &&
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
const { fieldName, nodeId } = overData.context;
|
||||
dispatch(
|
||||
fieldValueChanged({
|
||||
nodeId,
|
||||
fieldName,
|
||||
value: activeData.payload.imageDTO,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
// set multiple nodes images (single image handler)
|
||||
if (
|
||||
actionType === 'SET_MULTI_NODES_IMAGE' &&
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
const { fieldName, nodeId } = overData.context;
|
||||
dispatch(
|
||||
fieldValueChanged({
|
||||
nodeId,
|
||||
fieldName,
|
||||
value: [activeData.payload.imageDTO],
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
// set multiple nodes images (multiple images handler)
|
||||
if (
|
||||
actionType === 'SET_MULTI_NODES_IMAGE' &&
|
||||
activeData.payloadType === 'GALLERY_SELECTION'
|
||||
) {
|
||||
const { fieldName, nodeId } = overData.context;
|
||||
dispatch(
|
||||
imageCollectionFieldValueChanged({
|
||||
nodeId,
|
||||
fieldName,
|
||||
value: state.gallery.selection.map((image_name) => ({
|
||||
image_name,
|
||||
})),
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
// remove image from board
|
||||
// TODO: remove board_id from `removeImageFromBoard()` endpoint
|
||||
// TODO: handle multiple images
|
||||
// if (
|
||||
// actionType === 'MOVE_BOARD' &&
|
||||
// activeData.payloadType === 'IMAGE_DTO' &&
|
||||
// activeData.payload.imageDTO &&
|
||||
// overData.boardId !== null
|
||||
// ) {
|
||||
// const { image_name } = activeData.payload.imageDTO;
|
||||
// dispatch(
|
||||
// boardImagesApi.endpoints.removeImageFromBoard.initiate({ image_name })
|
||||
// );
|
||||
// }
|
||||
|
||||
// add image to board
|
||||
if (
|
||||
actionType === 'MOVE_BOARD' &&
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO &&
|
||||
overData.context.boardId
|
||||
) {
|
||||
const { image_name } = activeData.payload.imageDTO;
|
||||
const { boardId } = overData.context;
|
||||
dispatch(
|
||||
boardImagesApi.endpoints.addImageToBoard.initiate({
|
||||
image_name,
|
||||
board_id: boardId,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
// add multiple images to board
|
||||
// TODO: add endpoint
|
||||
// if (
|
||||
// actionType === 'ADD_TO_BATCH' &&
|
||||
// activeData.payloadType === 'IMAGE_NAMES' &&
|
||||
// activeData.payload.imageDTONames
|
||||
// ) {
|
||||
// dispatch(boardImagesApi.endpoints.addImagesToBoard.intiate({}));
|
||||
// }
|
||||
},
|
||||
});
|
||||
};
|
@ -1,7 +1,7 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { imageMetadataReceived, imageUpdated } from 'services/api/thunks/image';
|
||||
import { imageUpserted } from 'features/gallery/store/imagesSlice';
|
||||
import { imageUpserted } from 'features/gallery/store/gallerySlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'image' });
|
||||
|
||||
|
@ -0,0 +1,40 @@
|
||||
import { startAppListening } from '..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import {
|
||||
imageDeletionConfirmed,
|
||||
imageToDeleteSelected,
|
||||
isModalOpenChanged,
|
||||
selectImageUsage,
|
||||
} from 'features/imageDeletion/store/imageDeletionSlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'image' });
|
||||
|
||||
export const addImageToDeleteSelectedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: imageToDeleteSelected,
|
||||
effect: async (action, { dispatch, getState, condition }) => {
|
||||
const imageDTO = action.payload;
|
||||
const state = getState();
|
||||
const { shouldConfirmOnDelete } = state.system;
|
||||
const imageUsage = selectImageUsage(getState());
|
||||
|
||||
if (!imageUsage) {
|
||||
// should never happen
|
||||
return;
|
||||
}
|
||||
|
||||
const isImageInUse =
|
||||
imageUsage.isCanvasImage ||
|
||||
imageUsage.isInitialImage ||
|
||||
imageUsage.isControlNetImage ||
|
||||
imageUsage.isNodesImage;
|
||||
|
||||
if (shouldConfirmOnDelete || isImageInUse) {
|
||||
dispatch(isModalOpenChanged(true));
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(imageDeletionConfirmed({ imageDTO, imageUsage }));
|
||||
},
|
||||
});
|
||||
};
|
@ -2,11 +2,12 @@ import { startAppListening } from '..';
|
||||
import { imageUploaded } from 'services/api/thunks/image';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { imageUpserted } from 'features/gallery/store/imagesSlice';
|
||||
import { imageUpserted } from 'features/gallery/store/gallerySlice';
|
||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { imageAddedToBatch } from 'features/batch/store/batchSlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'image' });
|
||||
|
||||
@ -70,6 +71,11 @@ export const addImageUploadedFulfilledListener = () => {
|
||||
dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));
|
||||
return;
|
||||
}
|
||||
|
||||
if (postUploadAction?.type === 'ADD_TO_BATCH') {
|
||||
dispatch(imageAddedToBatch(image.image_name));
|
||||
return;
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { imageUrlsReceived } from 'services/api/thunks/image';
|
||||
import { imageUpdatedOne } from 'features/gallery/store/imagesSlice';
|
||||
import { imageUpdatedOne } from 'features/gallery/store/gallerySlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'image' });
|
||||
|
||||
|
@ -4,7 +4,7 @@ import { addToast } from 'features/system/store/systemSlice';
|
||||
import { startAppListening } from '..';
|
||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||
import { makeToast } from 'app/components/Toaster';
|
||||
import { selectImagesById } from 'features/gallery/store/imagesSlice';
|
||||
import { selectImagesById } from 'features/gallery/store/gallerySlice';
|
||||
import { isImageDTO } from 'services/api/guards';
|
||||
|
||||
export const addInitialImageSelectedListener = () => {
|
||||
|
@ -2,6 +2,7 @@ import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { receivedPageOfImages } from 'services/api/thunks/image';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'gallery' });
|
||||
|
||||
@ -9,11 +10,17 @@ export const addReceivedPageOfImagesFulfilledListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: receivedPageOfImages.fulfilled,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const page = action.payload;
|
||||
const { items } = action.payload;
|
||||
moduleLog.debug(
|
||||
{ data: { payload: action.payload } },
|
||||
`Received ${page.items.length} images`
|
||||
`Received ${items.length} images`
|
||||
);
|
||||
|
||||
items.forEach((image) => {
|
||||
dispatch(
|
||||
imagesApi.util.upsertQueryData('getImageDTO', image.image_name, image)
|
||||
);
|
||||
});
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -0,0 +1,19 @@
|
||||
import { startAppListening } from '..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import {
|
||||
imagesAddedToBatch,
|
||||
selectionAddedToBatch,
|
||||
} from 'features/batch/store/batchSlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'batch' });
|
||||
|
||||
export const addSelectionAddedToBatchListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: selectionAddedToBatch,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const { selection } = getState().gallery;
|
||||
|
||||
dispatch(imagesAddedToBatch(selection));
|
||||
},
|
||||
});
|
||||
};
|
@ -1,6 +1,5 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { appSocketConnected, socketConnected } from 'services/events/actions';
|
||||
import { receivedPageOfImages } from 'services/api/thunks/image';
|
||||
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
||||
import { startAppListening } from '../..';
|
||||
|
||||
@ -14,19 +13,10 @@ export const addSocketConnectedEventListener = () => {
|
||||
|
||||
moduleLog.debug({ timestamp }, 'Connected');
|
||||
|
||||
const { nodes, config, images } = getState();
|
||||
const { nodes, config } = getState();
|
||||
|
||||
const { disabledTabs } = config;
|
||||
|
||||
if (!images.ids.length) {
|
||||
dispatch(
|
||||
receivedPageOfImages({
|
||||
categories: ['general'],
|
||||
is_intermediate: false,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
if (!nodes.schema && !disabledTabs.includes('nodes')) {
|
||||
dispatch(receivedOpenAPISchema());
|
||||
}
|
||||
|
@ -2,7 +2,7 @@ import { stagingAreaImageSaved } from 'features/canvas/store/actions';
|
||||
import { startAppListening } from '..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { imageUpdated } from 'services/api/thunks/image';
|
||||
import { imageUpserted } from 'features/gallery/store/imagesSlice';
|
||||
import { imageUpserted } from 'features/gallery/store/gallerySlice';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'canvas' });
|
||||
|
@ -8,7 +8,7 @@ import { controlNetSelector } from 'features/controlNet/store/controlNetSlice';
|
||||
import { forEach, uniqBy } from 'lodash-es';
|
||||
import { imageUrlsReceived } from 'services/api/thunks/image';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { selectImagesEntities } from 'features/gallery/store/imagesSlice';
|
||||
import { selectImagesEntities } from 'features/gallery/store/gallerySlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'images' });
|
||||
|
||||
@ -36,7 +36,7 @@ const selectAllUsedImages = createSelector(
|
||||
nodes.nodes.forEach((node) => {
|
||||
forEach(node.data.inputs, (input) => {
|
||||
if (input.type === 'image' && input.value) {
|
||||
allUsedImages.push(input.value);
|
||||
allUsedImages.push(input.value.image_name);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
@ -8,31 +8,32 @@ import {
|
||||
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
||||
import { rememberEnhancer, rememberReducer } from 'redux-remember';
|
||||
|
||||
import batchReducer from 'features/batch/store/batchSlice';
|
||||
import canvasReducer from 'features/canvas/store/canvasSlice';
|
||||
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
|
||||
import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice';
|
||||
import boardsReducer from 'features/gallery/store/boardSlice';
|
||||
import galleryReducer from 'features/gallery/store/gallerySlice';
|
||||
import imagesReducer from 'features/gallery/store/imagesSlice';
|
||||
import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice';
|
||||
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
|
||||
import loraReducer from 'features/lora/store/loraSlice';
|
||||
import nodesReducer from 'features/nodes/store/nodesSlice';
|
||||
import generationReducer from 'features/parameters/store/generationSlice';
|
||||
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
||||
import systemReducer from 'features/system/store/systemSlice';
|
||||
// import sessionReducer from 'features/system/store/sessionSlice';
|
||||
import nodesReducer from 'features/nodes/store/nodesSlice';
|
||||
import boardsReducer from 'features/gallery/store/boardSlice';
|
||||
import configReducer from 'features/system/store/configSlice';
|
||||
import systemReducer from 'features/system/store/systemSlice';
|
||||
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
|
||||
import uiReducer from 'features/ui/store/uiSlice';
|
||||
import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice';
|
||||
|
||||
import { listenerMiddleware } from './middleware/listenerMiddleware';
|
||||
|
||||
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
|
||||
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
|
||||
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
||||
import { api } from 'services/api';
|
||||
import { LOCALSTORAGE_PREFIX } from './constants';
|
||||
import { serialize } from './enhancers/reduxRemember/serialize';
|
||||
import { unserialize } from './enhancers/reduxRemember/unserialize';
|
||||
import { api } from 'services/api';
|
||||
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
|
||||
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
|
||||
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
||||
|
||||
const allReducers = {
|
||||
canvas: canvasReducer,
|
||||
@ -45,11 +46,12 @@ const allReducers = {
|
||||
config: configReducer,
|
||||
ui: uiReducer,
|
||||
hotkeys: hotkeysReducer,
|
||||
images: imagesReducer,
|
||||
controlNet: controlNetReducer,
|
||||
boards: boardsReducer,
|
||||
// session: sessionReducer,
|
||||
dynamicPrompts: dynamicPromptsReducer,
|
||||
batch: batchReducer,
|
||||
imageDeletion: imageDeletionReducer,
|
||||
lora: loraReducer,
|
||||
[api.reducerPath]: api.reducer,
|
||||
};
|
||||
|
||||
@ -68,6 +70,8 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
|
||||
'ui',
|
||||
'controlNet',
|
||||
'dynamicPrompts',
|
||||
'batch',
|
||||
'lora',
|
||||
// 'boards',
|
||||
// 'hotkeys',
|
||||
// 'config',
|
||||
|
@ -15,10 +15,25 @@ export interface IAIButtonProps extends ButtonProps {
|
||||
}
|
||||
|
||||
const IAIButton = forwardRef((props: IAIButtonProps, forwardedRef) => {
|
||||
const { children, tooltip = '', tooltipProps, isChecked, ...rest } = props;
|
||||
const {
|
||||
children,
|
||||
tooltip = '',
|
||||
tooltipProps: { placement = 'top', hasArrow = true, ...tooltipProps } = {},
|
||||
isChecked,
|
||||
...rest
|
||||
} = props;
|
||||
return (
|
||||
<Tooltip label={tooltip} {...tooltipProps}>
|
||||
<Button ref={forwardedRef} aria-checked={isChecked} {...rest}>
|
||||
<Tooltip
|
||||
label={tooltip}
|
||||
placement={placement}
|
||||
hasArrow={hasArrow}
|
||||
{...tooltipProps}
|
||||
>
|
||||
<Button
|
||||
ref={forwardedRef}
|
||||
colorScheme={isChecked ? 'accent' : 'base'}
|
||||
{...rest}
|
||||
>
|
||||
{children}
|
||||
</Button>
|
||||
</Tooltip>
|
||||
|
@ -4,22 +4,25 @@ import {
|
||||
Collapse,
|
||||
Flex,
|
||||
Spacer,
|
||||
Switch,
|
||||
Text,
|
||||
useColorMode,
|
||||
useDisclosure,
|
||||
} from '@chakra-ui/react';
|
||||
import { AnimatePresence, motion } from 'framer-motion';
|
||||
import { PropsWithChildren, memo } from 'react';
|
||||
import { mode } from 'theme/util/mode';
|
||||
|
||||
export type IAIToggleCollapseProps = PropsWithChildren & {
|
||||
label: string;
|
||||
isOpen: boolean;
|
||||
onToggle: () => void;
|
||||
withSwitch?: boolean;
|
||||
activeLabel?: string;
|
||||
defaultIsOpen?: boolean;
|
||||
};
|
||||
|
||||
const IAICollapse = (props: IAIToggleCollapseProps) => {
|
||||
const { label, isOpen, onToggle, children, withSwitch = false } = props;
|
||||
const { label, activeLabel, children, defaultIsOpen = false } = props;
|
||||
const { isOpen, onToggle } = useDisclosure({ defaultIsOpen });
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
return (
|
||||
<Box>
|
||||
<Flex
|
||||
@ -28,6 +31,7 @@ const IAICollapse = (props: IAIToggleCollapseProps) => {
|
||||
alignItems: 'center',
|
||||
p: 2,
|
||||
px: 4,
|
||||
gap: 2,
|
||||
borderTopRadius: 'base',
|
||||
borderBottomRadius: isOpen ? 0 : 'base',
|
||||
bg: isOpen
|
||||
@ -48,19 +52,40 @@ const IAICollapse = (props: IAIToggleCollapseProps) => {
|
||||
}}
|
||||
>
|
||||
{label}
|
||||
<AnimatePresence>
|
||||
{activeLabel && (
|
||||
<motion.div
|
||||
key="statusText"
|
||||
initial={{
|
||||
opacity: 0,
|
||||
}}
|
||||
animate={{
|
||||
opacity: 1,
|
||||
transition: { duration: 0.1 },
|
||||
}}
|
||||
exit={{
|
||||
opacity: 0,
|
||||
transition: { duration: 0.1 },
|
||||
}}
|
||||
>
|
||||
<Text
|
||||
sx={{ color: 'accent.500', _dark: { color: 'accent.300' } }}
|
||||
>
|
||||
{activeLabel}
|
||||
</Text>
|
||||
</motion.div>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
<Spacer />
|
||||
{withSwitch && <Switch isChecked={isOpen} pointerEvents="none" />}
|
||||
{!withSwitch && (
|
||||
<ChevronUpIcon
|
||||
sx={{
|
||||
w: '1rem',
|
||||
h: '1rem',
|
||||
transform: isOpen ? 'rotate(0deg)' : 'rotate(180deg)',
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: 'normal',
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
<ChevronUpIcon
|
||||
sx={{
|
||||
w: '1rem',
|
||||
h: '1rem',
|
||||
transform: isOpen ? 'rotate(0deg)' : 'rotate(180deg)',
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: 'normal',
|
||||
}}
|
||||
/>
|
||||
</Flex>
|
||||
<Collapse in={isOpen} animateOpacity style={{ overflow: 'unset' }}>
|
||||
<Box
|
||||
|
@ -1,19 +1,20 @@
|
||||
import {
|
||||
Box,
|
||||
ChakraProps,
|
||||
Flex,
|
||||
Icon,
|
||||
IconButtonProps,
|
||||
Image,
|
||||
useColorMode,
|
||||
useColorModeValue,
|
||||
} from '@chakra-ui/react';
|
||||
import { useDraggable, useDroppable } from '@dnd-kit/core';
|
||||
import { useCombinedRefs } from '@dnd-kit/utilities';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
|
||||
import {
|
||||
IAILoadingImageFallback,
|
||||
IAINoContentFallback,
|
||||
} from 'common/components/IAIImageFallback';
|
||||
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
|
||||
import { AnimatePresence } from 'framer-motion';
|
||||
import { ReactElement, SyntheticEvent } from 'react';
|
||||
import { MouseEvent, ReactElement, SyntheticEvent } from 'react';
|
||||
import { memo, useRef } from 'react';
|
||||
import { FaImage, FaUndo, FaUpload } from 'react-icons/fa';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
@ -22,81 +23,97 @@ import IAIDropOverlay from './IAIDropOverlay';
|
||||
import { PostUploadAction } from 'services/api/thunks/image';
|
||||
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
|
||||
import { mode } from 'theme/util/mode';
|
||||
import {
|
||||
TypesafeDraggableData,
|
||||
TypesafeDroppableData,
|
||||
isValidDrop,
|
||||
useDraggable,
|
||||
useDroppable,
|
||||
} from 'app/components/ImageDnd/typesafeDnd';
|
||||
|
||||
type IAIDndImageProps = {
|
||||
image: ImageDTO | null | undefined;
|
||||
onDrop: (droppedImage: ImageDTO) => void;
|
||||
onReset?: () => void;
|
||||
imageDTO: ImageDTO | undefined;
|
||||
onError?: (event: SyntheticEvent<HTMLImageElement>) => void;
|
||||
onLoad?: (event: SyntheticEvent<HTMLImageElement>) => void;
|
||||
resetIconSize?: IconButtonProps['size'];
|
||||
onClick?: (event: MouseEvent<HTMLDivElement>) => void;
|
||||
onClickReset?: (event: MouseEvent<HTMLButtonElement>) => void;
|
||||
withResetIcon?: boolean;
|
||||
resetIcon?: ReactElement;
|
||||
resetTooltip?: string;
|
||||
withMetadataOverlay?: boolean;
|
||||
isDragDisabled?: boolean;
|
||||
isDropDisabled?: boolean;
|
||||
isUploadDisabled?: boolean;
|
||||
fallback?: ReactElement;
|
||||
payloadImage?: ImageDTO | null | undefined;
|
||||
minSize?: number;
|
||||
postUploadAction?: PostUploadAction;
|
||||
imageSx?: ChakraProps['sx'];
|
||||
fitContainer?: boolean;
|
||||
droppableData?: TypesafeDroppableData;
|
||||
draggableData?: TypesafeDraggableData;
|
||||
dropLabel?: string;
|
||||
isSelected?: boolean;
|
||||
thumbnail?: boolean;
|
||||
noContentFallback?: ReactElement;
|
||||
};
|
||||
|
||||
const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
const {
|
||||
image,
|
||||
onDrop,
|
||||
onReset,
|
||||
imageDTO,
|
||||
onClickReset,
|
||||
onError,
|
||||
resetIconSize = 'md',
|
||||
onClick,
|
||||
withResetIcon = false,
|
||||
withMetadataOverlay = false,
|
||||
isDropDisabled = false,
|
||||
isDragDisabled = false,
|
||||
isUploadDisabled = false,
|
||||
fallback = <IAIImageLoadingFallback />,
|
||||
payloadImage,
|
||||
minSize = 24,
|
||||
postUploadAction,
|
||||
imageSx,
|
||||
fitContainer = false,
|
||||
droppableData,
|
||||
draggableData,
|
||||
dropLabel,
|
||||
isSelected = false,
|
||||
thumbnail = false,
|
||||
resetTooltip = 'Reset',
|
||||
resetIcon = <FaUndo />,
|
||||
noContentFallback = <IAINoContentFallback icon={FaImage} />,
|
||||
} = props;
|
||||
|
||||
const dndId = useRef(uuidv4());
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const {
|
||||
isOver,
|
||||
setNodeRef: setDroppableRef,
|
||||
active: isDropActive,
|
||||
} = useDroppable({
|
||||
id: dndId.current,
|
||||
disabled: isDropDisabled,
|
||||
data: {
|
||||
handleDrop: onDrop,
|
||||
},
|
||||
});
|
||||
const dndId = useRef(uuidv4());
|
||||
|
||||
const {
|
||||
attributes,
|
||||
listeners,
|
||||
setNodeRef: setDraggableRef,
|
||||
isDragging,
|
||||
active,
|
||||
} = useDraggable({
|
||||
id: dndId.current,
|
||||
data: {
|
||||
image: payloadImage ? payloadImage : image,
|
||||
},
|
||||
disabled: isDragDisabled || !image,
|
||||
disabled: isDragDisabled || !imageDTO,
|
||||
data: draggableData,
|
||||
});
|
||||
|
||||
const { isOver, setNodeRef: setDroppableRef } = useDroppable({
|
||||
id: dndId.current,
|
||||
disabled: isDropDisabled,
|
||||
data: droppableData,
|
||||
});
|
||||
|
||||
const setDndRef = useCombinedRefs(setDroppableRef, setDraggableRef);
|
||||
|
||||
const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({
|
||||
postUploadAction,
|
||||
isDisabled: isUploadDisabled,
|
||||
});
|
||||
|
||||
const setNodeRef = useCombinedRefs(setDroppableRef, setDraggableRef);
|
||||
const resetIconShadow = useColorModeValue(
|
||||
`drop-shadow(0px 0px 0.1rem var(--invokeai-colors-base-600))`,
|
||||
`drop-shadow(0px 0px 0.1rem var(--invokeai-colors-base-800))`
|
||||
);
|
||||
|
||||
const uploadButtonStyles = isUploadDisabled
|
||||
? {}
|
||||
@ -117,16 +134,16 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
position: 'relative',
|
||||
minW: minSize,
|
||||
minH: minSize,
|
||||
minW: minSize ? minSize : undefined,
|
||||
minH: minSize ? minSize : undefined,
|
||||
userSelect: 'none',
|
||||
cursor: isDragDisabled || !image ? 'auto' : 'grab',
|
||||
cursor: isDragDisabled || !imageDTO ? 'default' : 'pointer',
|
||||
}}
|
||||
{...attributes}
|
||||
{...listeners}
|
||||
ref={setNodeRef}
|
||||
ref={setDndRef}
|
||||
>
|
||||
{image && (
|
||||
{imageDTO && (
|
||||
<Flex
|
||||
sx={{
|
||||
w: 'full',
|
||||
@ -137,42 +154,50 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
}}
|
||||
>
|
||||
<Image
|
||||
src={image.image_url}
|
||||
fallback={fallback}
|
||||
onClick={onClick}
|
||||
src={thumbnail ? imageDTO.thumbnail_url : imageDTO.image_url}
|
||||
fallbackStrategy="beforeLoadOrError"
|
||||
fallback={<IAILoadingImageFallback image={imageDTO} />}
|
||||
onError={onError}
|
||||
objectFit="contain"
|
||||
draggable={false}
|
||||
sx={{
|
||||
objectFit: 'contain',
|
||||
maxW: 'full',
|
||||
maxH: 'full',
|
||||
borderRadius: 'base',
|
||||
shadow: isSelected ? 'selected.light' : undefined,
|
||||
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
|
||||
...imageSx,
|
||||
}}
|
||||
/>
|
||||
{withMetadataOverlay && <ImageMetadataOverlay image={image} />}
|
||||
{onReset && withResetIcon && (
|
||||
<Box
|
||||
{withMetadataOverlay && <ImageMetadataOverlay image={imageDTO} />}
|
||||
{onClickReset && withResetIcon && (
|
||||
<IAIIconButton
|
||||
onClick={onClickReset}
|
||||
aria-label={resetTooltip}
|
||||
tooltip={resetTooltip}
|
||||
icon={resetIcon}
|
||||
size="sm"
|
||||
variant="link"
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
right: 0,
|
||||
top: 1,
|
||||
insetInlineEnd: 1,
|
||||
p: 0,
|
||||
minW: 0,
|
||||
svg: {
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: 'normal',
|
||||
fill: 'base.100',
|
||||
_hover: { fill: 'base.50' },
|
||||
filter: resetIconShadow,
|
||||
},
|
||||
}}
|
||||
>
|
||||
<IAIIconButton
|
||||
size={resetIconSize}
|
||||
tooltip="Reset Image"
|
||||
aria-label="Reset Image"
|
||||
icon={<FaUndo />}
|
||||
onClick={onReset}
|
||||
/>
|
||||
</Box>
|
||||
/>
|
||||
)}
|
||||
<AnimatePresence>
|
||||
{isDropActive && <IAIDropOverlay isOver={isOver} />}
|
||||
</AnimatePresence>
|
||||
</Flex>
|
||||
)}
|
||||
{!image && (
|
||||
{!imageDTO && !isUploadDisabled && (
|
||||
<>
|
||||
<Flex
|
||||
sx={{
|
||||
@ -191,17 +216,20 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
>
|
||||
<input {...getUploadInputProps()} />
|
||||
<Icon
|
||||
as={isUploadDisabled ? FaImage : FaUpload}
|
||||
as={FaUpload}
|
||||
sx={{
|
||||
boxSize: 12,
|
||||
boxSize: 16,
|
||||
}}
|
||||
/>
|
||||
</Flex>
|
||||
<AnimatePresence>
|
||||
{isDropActive && <IAIDropOverlay isOver={isOver} />}
|
||||
</AnimatePresence>
|
||||
</>
|
||||
)}
|
||||
{!imageDTO && isUploadDisabled && noContentFallback}
|
||||
<AnimatePresence>
|
||||
{isValidDrop(droppableData, active) && !isDragging && (
|
||||
<IAIDropOverlay isOver={isOver} label={dropLabel} />
|
||||
)}
|
||||
</AnimatePresence>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
@ -62,7 +62,7 @@ export const IAIDropOverlay = (props: Props) => {
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
opacity: 1,
|
||||
borderWidth: 2,
|
||||
borderWidth: 3,
|
||||
borderColor: isOver
|
||||
? mode('base.50', 'base.200')(colorMode)
|
||||
: mode('base.100', 'base.500')(colorMode),
|
||||
@ -78,10 +78,10 @@ export const IAIDropOverlay = (props: Props) => {
|
||||
sx={{
|
||||
fontSize: '2xl',
|
||||
fontWeight: 600,
|
||||
transform: isOver ? 'scale(1.1)' : 'scale(1)',
|
||||
transform: isOver ? 'scale(1.02)' : 'scale(1)',
|
||||
color: isOver
|
||||
? mode('base.100', 'base.100')(colorMode)
|
||||
: mode('base.200', 'base.500')(colorMode),
|
||||
? mode('base.50', 'base.50')(colorMode)
|
||||
: mode('base.100', 'base.200')(colorMode),
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: '0.1s',
|
||||
}}
|
||||
|
@ -29,7 +29,7 @@ const IAIIconButton = forwardRef((props: IAIIconButtonProps, forwardedRef) => {
|
||||
<IconButton
|
||||
ref={forwardedRef}
|
||||
role={role}
|
||||
aria-checked={isChecked !== undefined ? isChecked : undefined}
|
||||
colorScheme={isChecked ? 'accent' : 'base'}
|
||||
{...rest}
|
||||
/>
|
||||
</Tooltip>
|
||||
|
@ -1,73 +1,82 @@
|
||||
import {
|
||||
As,
|
||||
ChakraProps,
|
||||
Flex,
|
||||
FlexProps,
|
||||
Icon,
|
||||
IconProps,
|
||||
Skeleton,
|
||||
Spinner,
|
||||
SpinnerProps,
|
||||
useColorMode,
|
||||
StyleProps,
|
||||
Text,
|
||||
} from '@chakra-ui/react';
|
||||
import { FaImage } from 'react-icons/fa';
|
||||
import { mode } from 'theme/util/mode';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
|
||||
type Props = FlexProps & {
|
||||
spinnerProps?: SpinnerProps;
|
||||
};
|
||||
type Props = { image: ImageDTO | undefined };
|
||||
|
||||
export const IAILoadingImageFallback = (props: Props) => {
|
||||
if (props.image) {
|
||||
return (
|
||||
<Skeleton
|
||||
sx={{
|
||||
w: `${props.image.width}px`,
|
||||
h: 'auto',
|
||||
objectFit: 'contain',
|
||||
aspectRatio: `${props.image.width}/${props.image.height}`,
|
||||
}}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
export const IAIImageLoadingFallback = (props: Props) => {
|
||||
const { spinnerProps, ...rest } = props;
|
||||
const { sx, ...restFlexProps } = rest;
|
||||
const { colorMode } = useColorMode();
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
bg: mode('base.200', 'base.900')(colorMode),
|
||||
opacity: 0.7,
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
borderRadius: 'base',
|
||||
...sx,
|
||||
bg: 'base.200',
|
||||
_dark: {
|
||||
bg: 'base.900',
|
||||
},
|
||||
}}
|
||||
{...restFlexProps}
|
||||
>
|
||||
<Spinner size="xl" {...spinnerProps} />
|
||||
<Spinner size="xl" />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
type IAINoImageFallbackProps = {
|
||||
flexProps?: FlexProps;
|
||||
iconProps?: IconProps;
|
||||
as?: As;
|
||||
label?: string;
|
||||
icon?: As;
|
||||
boxSize?: StyleProps['boxSize'];
|
||||
sx?: ChakraProps['sx'];
|
||||
};
|
||||
|
||||
export const IAINoImageFallback = (props: IAINoImageFallbackProps) => {
|
||||
const { sx: flexSx, ...restFlexProps } = props.flexProps ?? { sx: {} };
|
||||
const { sx: iconSx, ...restIconProps } = props.iconProps ?? { sx: {} };
|
||||
const { colorMode } = useColorMode();
|
||||
export const IAINoContentFallback = (props: IAINoImageFallbackProps) => {
|
||||
const { icon = FaImage, boxSize = 16 } = props;
|
||||
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
bg: mode('base.200', 'base.900')(colorMode),
|
||||
opacity: 0.7,
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
borderRadius: 'base',
|
||||
...flexSx,
|
||||
flexDir: 'column',
|
||||
gap: 2,
|
||||
userSelect: 'none',
|
||||
color: 'base.700',
|
||||
_dark: {
|
||||
color: 'base.500',
|
||||
},
|
||||
...props.sx,
|
||||
}}
|
||||
{...restFlexProps}
|
||||
>
|
||||
<Icon
|
||||
as={props.as ?? FaImage}
|
||||
sx={{ color: mode('base.700', 'base.500')(colorMode), ...iconSx }}
|
||||
{...restIconProps}
|
||||
/>
|
||||
<Icon as={icon} boxSize={boxSize} opacity={0.7} />
|
||||
{props.label && <Text textAlign="center">{props.label}</Text>}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
@ -1,15 +1,16 @@
|
||||
import { Tooltip, useColorMode, useToken } from '@chakra-ui/react';
|
||||
import { MultiSelect, MultiSelectProps } from '@mantine/core';
|
||||
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
|
||||
import { memo } from 'react';
|
||||
import { RefObject, memo } from 'react';
|
||||
import { mode } from 'theme/util/mode';
|
||||
|
||||
type IAIMultiSelectProps = MultiSelectProps & {
|
||||
tooltip?: string;
|
||||
inputRef?: RefObject<HTMLInputElement>;
|
||||
};
|
||||
|
||||
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
||||
const { searchable = true, tooltip, ...rest } = props;
|
||||
const { searchable = true, tooltip, inputRef, ...rest } = props;
|
||||
const {
|
||||
base50,
|
||||
base100,
|
||||
@ -33,6 +34,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
||||
return (
|
||||
<Tooltip label={tooltip} placement="top" hasArrow>
|
||||
<MultiSelect
|
||||
ref={inputRef}
|
||||
searchable={searchable}
|
||||
styles={() => ({
|
||||
label: {
|
||||
@ -61,7 +63,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
||||
'&:focus-within': {
|
||||
borderColor: mode(accent200, accent600)(colorMode),
|
||||
},
|
||||
'&:disabled': {
|
||||
'&[data-disabled]': {
|
||||
backgroundColor: mode(base300, base700)(colorMode),
|
||||
color: mode(base600, base400)(colorMode),
|
||||
},
|
||||
|
@ -64,7 +64,7 @@ const IAIMantineSelect = (props: IAISelectProps) => {
|
||||
'&:focus-within': {
|
||||
borderColor: mode(accent200, accent600)(colorMode),
|
||||
},
|
||||
'&:disabled': {
|
||||
'&[data-disabled]': {
|
||||
backgroundColor: mode(base300, base700)(colorMode),
|
||||
color: mode(base600, base400)(colorMode),
|
||||
},
|
||||
|
@ -36,7 +36,6 @@ const IAISwitch = (props: Props) => {
|
||||
isDisabled={isDisabled}
|
||||
width={width}
|
||||
display="flex"
|
||||
gap={4}
|
||||
alignItems="center"
|
||||
{...formControlProps}
|
||||
>
|
||||
@ -47,6 +46,7 @@ const IAISwitch = (props: Props) => {
|
||||
sx={{
|
||||
cursor: isDisabled ? 'not-allowed' : 'pointer',
|
||||
...formLabelProps?.sx,
|
||||
pe: 4,
|
||||
}}
|
||||
{...formLabelProps}
|
||||
>
|
||||
|
@ -1,27 +1,49 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { validateSeedWeights } from 'common/util/seedWeightPairs';
|
||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import {
|
||||
modelsApi,
|
||||
useGetMainModelsQuery,
|
||||
} from '../../services/api/endpoints/models';
|
||||
|
||||
const readinessSelector = createSelector(
|
||||
[generationSelector, systemSelector, activeTabNameSelector],
|
||||
(generation, system, activeTabName) => {
|
||||
[stateSelector, activeTabNameSelector],
|
||||
(state, activeTabName) => {
|
||||
const { generation, system, batch } = state;
|
||||
const { shouldGenerateVariations, seedWeights, initialImage, seed } =
|
||||
generation;
|
||||
|
||||
const { isProcessing, isConnected } = system;
|
||||
const {
|
||||
isEnabled: isBatchEnabled,
|
||||
asInitialImage,
|
||||
imageNames: batchImageNames,
|
||||
} = batch;
|
||||
|
||||
let isReady = true;
|
||||
const reasonsWhyNotReady: string[] = [];
|
||||
|
||||
if (activeTabName === 'img2img' && !initialImage) {
|
||||
if (
|
||||
activeTabName === 'img2img' &&
|
||||
!initialImage &&
|
||||
!(asInitialImage && batchImageNames.length > 1)
|
||||
) {
|
||||
isReady = false;
|
||||
reasonsWhyNotReady.push('No initial image selected');
|
||||
}
|
||||
|
||||
const { isSuccess: mainModelsSuccessfullyLoaded } =
|
||||
modelsApi.endpoints.getMainModels.select()(state);
|
||||
if (!mainModelsSuccessfullyLoaded) {
|
||||
isReady = false;
|
||||
reasonsWhyNotReady.push('Models are not loaded');
|
||||
}
|
||||
|
||||
// TODO: job queue
|
||||
// Cannot generate if already processing an image
|
||||
if (isProcessing) {
|
||||
|
@ -0,0 +1,67 @@
|
||||
import {
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
Heading,
|
||||
Spacer,
|
||||
Switch,
|
||||
Text,
|
||||
} from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { ControlNetConfig } from 'features/controlNet/store/controlNetSlice';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import { controlNetToggled } from '../store/batchSlice';
|
||||
|
||||
type Props = {
|
||||
controlNet: ControlNetConfig;
|
||||
};
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector, (state, controlNetId: string) => controlNetId],
|
||||
(state, controlNetId) => {
|
||||
const isControlNetEnabled = state.batch.controlNets.includes(controlNetId);
|
||||
return { isControlNetEnabled };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const BatchControlNet = (props: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { isControlNetEnabled } = useAppSelector((state) =>
|
||||
selector(state, props.controlNet.controlNetId)
|
||||
);
|
||||
const { processorType, model } = props.controlNet;
|
||||
|
||||
const handleChangeAsControlNet = useCallback(() => {
|
||||
dispatch(controlNetToggled(props.controlNet.controlNetId));
|
||||
}, [dispatch, props.controlNet.controlNetId]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
layerStyle="second"
|
||||
sx={{ flexDir: 'column', gap: 1, p: 4, borderRadius: 'base' }}
|
||||
>
|
||||
<Flex sx={{ justifyContent: 'space-between' }}>
|
||||
<FormControl as={Flex} onClick={handleChangeAsControlNet}>
|
||||
<FormLabel>
|
||||
<Heading size="sm">ControlNet</Heading>
|
||||
</FormLabel>
|
||||
<Spacer />
|
||||
<Switch isChecked={isControlNetEnabled} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Text>
|
||||
<strong>Model:</strong> {model}
|
||||
</Text>
|
||||
<Text>
|
||||
<strong>Processor:</strong> {processorType}
|
||||
</Text>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(BatchControlNet);
|
@ -0,0 +1,116 @@
|
||||
import { Box, Icon, Skeleton } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import {
|
||||
batchImageRangeEndSelected,
|
||||
batchImageSelected,
|
||||
batchImageSelectionToggled,
|
||||
imageRemovedFromBatch,
|
||||
} from 'features/batch/store/batchSlice';
|
||||
import { MouseEvent, memo, useCallback, useMemo } from 'react';
|
||||
import { FaExclamationCircle } from 'react-icons/fa';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
|
||||
const makeSelector = (image_name: string) =>
|
||||
createSelector(
|
||||
[stateSelector],
|
||||
(state) => ({
|
||||
selectionCount: state.batch.selection.length,
|
||||
isSelected: state.batch.selection.includes(image_name),
|
||||
}),
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
type BatchImageProps = {
|
||||
imageName: string;
|
||||
};
|
||||
|
||||
const BatchImage = (props: BatchImageProps) => {
|
||||
const {
|
||||
currentData: imageDTO,
|
||||
isFetching,
|
||||
isError,
|
||||
isSuccess,
|
||||
} = useGetImageDTOQuery(props.imageName);
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const selector = useMemo(
|
||||
() => makeSelector(props.imageName),
|
||||
[props.imageName]
|
||||
);
|
||||
|
||||
const { isSelected, selectionCount } = useAppSelector(selector);
|
||||
|
||||
const handleClickRemove = useCallback(() => {
|
||||
dispatch(imageRemovedFromBatch(props.imageName));
|
||||
}, [dispatch, props.imageName]);
|
||||
|
||||
const handleClick = useCallback(
|
||||
(e: MouseEvent<HTMLDivElement>) => {
|
||||
if (e.shiftKey) {
|
||||
dispatch(batchImageRangeEndSelected(props.imageName));
|
||||
} else if (e.ctrlKey || e.metaKey) {
|
||||
dispatch(batchImageSelectionToggled(props.imageName));
|
||||
} else {
|
||||
dispatch(batchImageSelected(props.imageName));
|
||||
}
|
||||
},
|
||||
[dispatch, props.imageName]
|
||||
);
|
||||
|
||||
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
|
||||
if (selectionCount > 1) {
|
||||
return {
|
||||
id: 'batch',
|
||||
payloadType: 'BATCH_SELECTION',
|
||||
};
|
||||
}
|
||||
|
||||
if (imageDTO) {
|
||||
return {
|
||||
id: 'batch',
|
||||
payloadType: 'IMAGE_DTO',
|
||||
payload: { imageDTO },
|
||||
};
|
||||
}
|
||||
}, [imageDTO, selectionCount]);
|
||||
|
||||
if (isError) {
|
||||
return <Icon as={FaExclamationCircle} />;
|
||||
}
|
||||
|
||||
if (isFetching) {
|
||||
return (
|
||||
<Skeleton>
|
||||
<Box w="full" h="full" aspectRatio="1/1" />
|
||||
</Skeleton>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Box sx={{ position: 'relative', aspectRatio: '1/1' }}>
|
||||
<IAIDndImage
|
||||
imageDTO={imageDTO}
|
||||
draggableData={draggableData}
|
||||
isDropDisabled={true}
|
||||
isUploadDisabled={true}
|
||||
imageSx={{
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
}}
|
||||
onClick={handleClick}
|
||||
isSelected={isSelected}
|
||||
onClickReset={handleClickRemove}
|
||||
resetTooltip="Remove from batch"
|
||||
withResetIcon
|
||||
thumbnail
|
||||
/>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(BatchImage);
|
@ -0,0 +1,31 @@
|
||||
import { Box } from '@chakra-ui/react';
|
||||
import BatchImageGrid from './BatchImageGrid';
|
||||
import IAIDropOverlay from 'common/components/IAIDropOverlay';
|
||||
import {
|
||||
AddToBatchDropData,
|
||||
isValidDrop,
|
||||
useDroppable,
|
||||
} from 'app/components/ImageDnd/typesafeDnd';
|
||||
|
||||
const droppableData: AddToBatchDropData = {
|
||||
id: 'batch',
|
||||
actionType: 'ADD_TO_BATCH',
|
||||
};
|
||||
|
||||
const BatchImageContainer = () => {
|
||||
const { isOver, setNodeRef, active } = useDroppable({
|
||||
id: 'batch-manager',
|
||||
data: droppableData,
|
||||
});
|
||||
|
||||
return (
|
||||
<Box ref={setNodeRef} position="relative" w="full" h="full">
|
||||
<BatchImageGrid />
|
||||
{isValidDrop(droppableData, active) && (
|
||||
<IAIDropOverlay isOver={isOver} label="Add to Batch" />
|
||||
)}
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default BatchImageContainer;
|
@ -0,0 +1,54 @@
|
||||
import { FaImages } from 'react-icons/fa';
|
||||
import { Grid, GridItem } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import BatchImage from './BatchImage';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
(state) => {
|
||||
const imageNames = state.batch.imageNames.concat().reverse();
|
||||
|
||||
return { imageNames };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const BatchImageGrid = () => {
|
||||
const { imageNames } = useAppSelector(selector);
|
||||
|
||||
if (imageNames.length === 0) {
|
||||
return (
|
||||
<IAINoContentFallback
|
||||
icon={FaImages}
|
||||
boxSize={16}
|
||||
label="No images in Batch"
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Grid
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
flexWrap: 'wrap',
|
||||
w: 'full',
|
||||
minH: 0,
|
||||
maxH: 'full',
|
||||
overflowY: 'scroll',
|
||||
gridTemplateColumns: `repeat(auto-fill, minmax(128px, 1fr))`,
|
||||
}}
|
||||
>
|
||||
{imageNames.map((imageName) => (
|
||||
<GridItem key={imageName} sx={{ p: 1.5 }}>
|
||||
<BatchImage imageName={imageName} />
|
||||
</GridItem>
|
||||
))}
|
||||
</Grid>
|
||||
);
|
||||
};
|
||||
|
||||
export default BatchImageGrid;
|
@ -0,0 +1,103 @@
|
||||
import { Flex, Heading, Spacer } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useCallback } from 'react';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import {
|
||||
asInitialImageToggled,
|
||||
batchReset,
|
||||
isEnabledChanged,
|
||||
} from 'features/batch/store/batchSlice';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import BatchImageContainer from './BatchImageGrid';
|
||||
import { map } from 'lodash-es';
|
||||
import BatchControlNet from './BatchControlNet';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
(state) => {
|
||||
const { controlNets } = state.controlNet;
|
||||
const {
|
||||
imageNames,
|
||||
asInitialImage,
|
||||
controlNets: batchControlNets,
|
||||
isEnabled,
|
||||
} = state.batch;
|
||||
|
||||
return {
|
||||
imageCount: imageNames.length,
|
||||
asInitialImage,
|
||||
controlNets,
|
||||
batchControlNets,
|
||||
isEnabled,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const BatchManager = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { imageCount, isEnabled, controlNets, batchControlNets } =
|
||||
useAppSelector(selector);
|
||||
|
||||
const handleResetBatch = useCallback(() => {
|
||||
dispatch(batchReset());
|
||||
}, [dispatch]);
|
||||
|
||||
const handleToggle = useCallback(() => {
|
||||
dispatch(isEnabledChanged(!isEnabled));
|
||||
}, [dispatch, isEnabled]);
|
||||
|
||||
const handleChangeAsInitialImage = useCallback(() => {
|
||||
dispatch(asInitialImageToggled());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
h: 'full',
|
||||
w: 'full',
|
||||
flexDir: 'column',
|
||||
position: 'relative',
|
||||
gap: 2,
|
||||
minW: 0,
|
||||
}}
|
||||
>
|
||||
<Flex sx={{ alignItems: 'center' }}>
|
||||
<Heading
|
||||
size={'md'}
|
||||
sx={{ color: 'base.800', _dark: { color: 'base.200' } }}
|
||||
>
|
||||
{imageCount || 'No'} images
|
||||
</Heading>
|
||||
<Spacer />
|
||||
<IAIButton onClick={handleResetBatch}>Reset</IAIButton>
|
||||
</Flex>
|
||||
<Flex
|
||||
sx={{
|
||||
alignItems: 'center',
|
||||
flexDir: 'column',
|
||||
gap: 4,
|
||||
}}
|
||||
>
|
||||
<IAISwitch
|
||||
label="Use as Initial Image"
|
||||
onChange={handleChangeAsInitialImage}
|
||||
/>
|
||||
{map(controlNets, (controlNet) => {
|
||||
return (
|
||||
<BatchControlNet
|
||||
key={controlNet.controlNetId}
|
||||
controlNet={controlNet}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</Flex>
|
||||
<BatchImageContainer />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default BatchManager;
|
142
invokeai/frontend/web/src/features/batch/store/batchSlice.ts
Normal file
142
invokeai/frontend/web/src/features/batch/store/batchSlice.ts
Normal file
@ -0,0 +1,142 @@
|
||||
import { PayloadAction, createAction, createSlice } from '@reduxjs/toolkit';
|
||||
import { uniq } from 'lodash-es';
|
||||
import { imageDeleted } from 'services/api/thunks/image';
|
||||
|
||||
type BatchState = {
|
||||
isEnabled: boolean;
|
||||
imageNames: string[];
|
||||
asInitialImage: boolean;
|
||||
controlNets: string[];
|
||||
selection: string[];
|
||||
};
|
||||
|
||||
export const initialBatchState: BatchState = {
|
||||
isEnabled: false,
|
||||
imageNames: [],
|
||||
asInitialImage: false,
|
||||
controlNets: [],
|
||||
selection: [],
|
||||
};
|
||||
|
||||
const batch = createSlice({
|
||||
name: 'batch',
|
||||
initialState: initialBatchState,
|
||||
reducers: {
|
||||
isEnabledChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.isEnabled = action.payload;
|
||||
},
|
||||
imageAddedToBatch: (state, action: PayloadAction<string>) => {
|
||||
state.imageNames = uniq(state.imageNames.concat(action.payload));
|
||||
},
|
||||
imagesAddedToBatch: (state, action: PayloadAction<string[]>) => {
|
||||
state.imageNames = uniq(state.imageNames.concat(action.payload));
|
||||
},
|
||||
imageRemovedFromBatch: (state, action: PayloadAction<string>) => {
|
||||
state.imageNames = state.imageNames.filter(
|
||||
(imageName) => action.payload !== imageName
|
||||
);
|
||||
state.selection = state.selection.filter(
|
||||
(imageName) => action.payload !== imageName
|
||||
);
|
||||
},
|
||||
imagesRemovedFromBatch: (state, action: PayloadAction<string[]>) => {
|
||||
state.imageNames = state.imageNames.filter(
|
||||
(imageName) => !action.payload.includes(imageName)
|
||||
);
|
||||
state.selection = state.selection.filter(
|
||||
(imageName) => !action.payload.includes(imageName)
|
||||
);
|
||||
},
|
||||
batchImageRangeEndSelected: (state, action: PayloadAction<string>) => {
|
||||
const rangeEndImageName = action.payload;
|
||||
const lastSelectedImage = state.selection[state.selection.length - 1];
|
||||
const lastClickedIndex = state.imageNames.findIndex(
|
||||
(n) => n === lastSelectedImage
|
||||
);
|
||||
const currentClickedIndex = state.imageNames.findIndex(
|
||||
(n) => n === rangeEndImageName
|
||||
);
|
||||
if (lastClickedIndex > -1 && currentClickedIndex > -1) {
|
||||
// We have a valid range!
|
||||
const start = Math.min(lastClickedIndex, currentClickedIndex);
|
||||
const end = Math.max(lastClickedIndex, currentClickedIndex);
|
||||
|
||||
const imagesToSelect = state.imageNames.slice(start, end + 1);
|
||||
state.selection = uniq(state.selection.concat(imagesToSelect));
|
||||
}
|
||||
},
|
||||
batchImageSelectionToggled: (state, action: PayloadAction<string>) => {
|
||||
if (
|
||||
state.selection.includes(action.payload) &&
|
||||
state.selection.length > 1
|
||||
) {
|
||||
state.selection = state.selection.filter(
|
||||
(imageName) => imageName !== action.payload
|
||||
);
|
||||
} else {
|
||||
state.selection = uniq(state.selection.concat(action.payload));
|
||||
}
|
||||
},
|
||||
batchImageSelected: (state, action: PayloadAction<string | null>) => {
|
||||
state.selection = action.payload
|
||||
? [action.payload]
|
||||
: [String(state.imageNames[0])];
|
||||
},
|
||||
batchReset: (state) => {
|
||||
state.imageNames = [];
|
||||
state.selection = [];
|
||||
},
|
||||
asInitialImageToggled: (state) => {
|
||||
state.asInitialImage = !state.asInitialImage;
|
||||
},
|
||||
controlNetAddedToBatch: (state, action: PayloadAction<string>) => {
|
||||
state.controlNets = uniq(state.controlNets.concat(action.payload));
|
||||
},
|
||||
controlNetRemovedFromBatch: (state, action: PayloadAction<string>) => {
|
||||
state.controlNets = state.controlNets.filter(
|
||||
(controlNetId) => controlNetId !== action.payload
|
||||
);
|
||||
},
|
||||
controlNetToggled: (state, action: PayloadAction<string>) => {
|
||||
if (state.controlNets.includes(action.payload)) {
|
||||
state.controlNets = state.controlNets.filter(
|
||||
(controlNetId) => controlNetId !== action.payload
|
||||
);
|
||||
} else {
|
||||
state.controlNets = uniq(state.controlNets.concat(action.payload));
|
||||
}
|
||||
},
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder.addCase(imageDeleted.fulfilled, (state, action) => {
|
||||
state.imageNames = state.imageNames.filter(
|
||||
(imageName) => imageName !== action.meta.arg.image_name
|
||||
);
|
||||
state.selection = state.selection.filter(
|
||||
(imageName) => imageName !== action.meta.arg.image_name
|
||||
);
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
isEnabledChanged,
|
||||
imageAddedToBatch,
|
||||
imagesAddedToBatch,
|
||||
imageRemovedFromBatch,
|
||||
imagesRemovedFromBatch,
|
||||
asInitialImageToggled,
|
||||
controlNetAddedToBatch,
|
||||
controlNetRemovedFromBatch,
|
||||
batchReset,
|
||||
controlNetToggled,
|
||||
batchImageRangeEndSelected,
|
||||
batchImageSelectionToggled,
|
||||
batchImageSelected,
|
||||
} = batch.actions;
|
||||
|
||||
export default batch.reducer;
|
||||
|
||||
export const selectionAddedToBatch = createAction(
|
||||
'batch/selectionAddedToBatch'
|
||||
);
|
@ -1,20 +1,22 @@
|
||||
import { memo, useCallback, useState } from 'react';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { Box, Flex, SystemStyleObject } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
import {
|
||||
TypesafeDraggableData,
|
||||
TypesafeDroppableData,
|
||||
} from 'app/components/ImageDnd/typesafeDnd';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import { IAILoadingImageFallback } from 'common/components/IAIImageFallback';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { PostUploadAction } from 'services/api/thunks/image';
|
||||
import {
|
||||
ControlNetConfig,
|
||||
controlNetImageChanged,
|
||||
controlNetSelector,
|
||||
} from '../store/controlNetSlice';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { Box, Flex, SystemStyleObject } from '@chakra-ui/react';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { FaUndo } from 'react-icons/fa';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
|
||||
const selector = createSelector(
|
||||
controlNetSelector,
|
||||
@ -57,22 +59,6 @@ const ControlNetImagePreview = (props: Props) => {
|
||||
isSuccess: isSuccessProcessedControlImage,
|
||||
} = useGetImageDTOQuery(processedControlImageName ?? skipToken);
|
||||
|
||||
const handleDrop = useCallback(
|
||||
(droppedImage: ImageDTO) => {
|
||||
if (controlImageName === droppedImage.image_name) {
|
||||
return;
|
||||
}
|
||||
setIsMouseOverImage(false);
|
||||
dispatch(
|
||||
controlNetImageChanged({
|
||||
controlNetId,
|
||||
controlImage: droppedImage.image_name,
|
||||
})
|
||||
);
|
||||
},
|
||||
[controlImageName, controlNetId, dispatch]
|
||||
);
|
||||
|
||||
const handleResetControlImage = useCallback(() => {
|
||||
dispatch(controlNetImageChanged({ controlNetId, controlImage: null }));
|
||||
}, [controlNetId, dispatch]);
|
||||
@ -84,6 +70,30 @@ const ControlNetImagePreview = (props: Props) => {
|
||||
setIsMouseOverImage(false);
|
||||
}, []);
|
||||
|
||||
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
|
||||
if (controlImage) {
|
||||
return {
|
||||
id: controlNetId,
|
||||
payloadType: 'IMAGE_DTO',
|
||||
payload: { imageDTO: controlImage },
|
||||
};
|
||||
}
|
||||
}, [controlImage, controlNetId]);
|
||||
|
||||
const droppableData = useMemo<TypesafeDroppableData | undefined>(
|
||||
() => ({
|
||||
id: controlNetId,
|
||||
actionType: 'SET_CONTROLNET_IMAGE',
|
||||
context: { controlNetId },
|
||||
}),
|
||||
[controlNetId]
|
||||
);
|
||||
|
||||
const postUploadAction = useMemo<PostUploadAction>(
|
||||
() => ({ type: 'SET_CONTROLNET_IMAGE', controlNetId }),
|
||||
[controlNetId]
|
||||
);
|
||||
|
||||
const shouldShowProcessedImage =
|
||||
controlImage &&
|
||||
processedControlImage &&
|
||||
@ -104,14 +114,14 @@ const ControlNetImagePreview = (props: Props) => {
|
||||
}}
|
||||
>
|
||||
<IAIDndImage
|
||||
image={controlImage}
|
||||
onDrop={handleDrop}
|
||||
draggableData={draggableData}
|
||||
droppableData={droppableData}
|
||||
imageDTO={controlImage}
|
||||
isDropDisabled={shouldShowProcessedImage}
|
||||
postUploadAction={{ type: 'SET_CONTROLNET_IMAGE', controlNetId }}
|
||||
imageSx={{
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
}}
|
||||
onClickReset={handleResetControlImage}
|
||||
postUploadAction={postUploadAction}
|
||||
resetTooltip="Reset Control Image"
|
||||
withResetIcon={Boolean(controlImage)}
|
||||
/>
|
||||
<Box
|
||||
sx={{
|
||||
@ -127,14 +137,13 @@ const ControlNetImagePreview = (props: Props) => {
|
||||
}}
|
||||
>
|
||||
<IAIDndImage
|
||||
image={processedControlImage}
|
||||
onDrop={handleDrop}
|
||||
payloadImage={controlImage}
|
||||
draggableData={draggableData}
|
||||
droppableData={droppableData}
|
||||
imageDTO={processedControlImage}
|
||||
isUploadDisabled={true}
|
||||
imageSx={{
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
}}
|
||||
onClickReset={handleResetControlImage}
|
||||
resetTooltip="Reset Control Image"
|
||||
withResetIcon={Boolean(controlImage)}
|
||||
/>
|
||||
</Box>
|
||||
{pendingControlImages.includes(controlNetId) && (
|
||||
@ -145,27 +154,12 @@ const ControlNetImagePreview = (props: Props) => {
|
||||
insetInlineStart: 0,
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
objectFit: 'contain',
|
||||
}}
|
||||
>
|
||||
<IAIImageLoadingFallback />
|
||||
<IAILoadingImageFallback image={controlImage} />
|
||||
</Box>
|
||||
)}
|
||||
{controlImage && (
|
||||
<Flex sx={{ position: 'absolute', top: 0, insetInlineEnd: 0 }}>
|
||||
<IAIIconButton
|
||||
aria-label="Reset Control Image"
|
||||
tooltip="Reset Control Image"
|
||||
size="sm"
|
||||
onClick={handleResetControlImage}
|
||||
icon={<FaUndo />}
|
||||
variant="link"
|
||||
sx={{
|
||||
p: 2,
|
||||
color: 'base.50',
|
||||
}}
|
||||
/>
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
@ -0,0 +1,36 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { isControlNetEnabledToggled } from 'features/controlNet/store/controlNetSlice';
|
||||
import { useCallback } from 'react';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
(state) => {
|
||||
const { isEnabled } = state.controlNet;
|
||||
|
||||
return { isEnabled };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamControlNetFeatureToggle = () => {
|
||||
const { isEnabled } = useAppSelector(selector);
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleChange = useCallback(() => {
|
||||
dispatch(isControlNetEnabledToggled());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<IAISwitch
|
||||
label="Enable ControlNet"
|
||||
isChecked={isEnabled}
|
||||
onChange={handleChange}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default ParamControlNetFeatureToggle;
|
@ -0,0 +1,15 @@
|
||||
import { filter } from 'lodash-es';
|
||||
import { ControlNetConfig } from '../store/controlNetSlice';
|
||||
|
||||
export const getValidControlNets = (
|
||||
controlNets: Record<string, ControlNetConfig>
|
||||
) => {
|
||||
const validControlNets = filter(
|
||||
controlNets,
|
||||
(c) =>
|
||||
c.isEnabled &&
|
||||
(Boolean(c.processedControlImage) ||
|
||||
(c.processorType === 'none' && Boolean(c.controlImage)))
|
||||
);
|
||||
return validControlNets;
|
||||
};
|
@ -1,40 +1,30 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAICollapse from 'common/components/IAICollapse';
|
||||
import { useCallback } from 'react';
|
||||
import { isEnabledToggled } from '../store/slice';
|
||||
import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts';
|
||||
import ParamDynamicPromptsCombinatorial from './ParamDynamicPromptsCombinatorial';
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import ParamDynamicPromptsToggle from './ParamDynamicPromptsEnabled';
|
||||
import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
(state) => {
|
||||
const { isEnabled } = state.dynamicPrompts;
|
||||
|
||||
return { isEnabled };
|
||||
return { activeLabel: isEnabled ? 'Enabled' : undefined };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamDynamicPromptsCollapse = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { isEnabled } = useAppSelector(selector);
|
||||
|
||||
const handleToggleIsEnabled = useCallback(() => {
|
||||
dispatch(isEnabledToggled());
|
||||
}, [dispatch]);
|
||||
const { activeLabel } = useAppSelector(selector);
|
||||
|
||||
return (
|
||||
<IAICollapse
|
||||
isOpen={isEnabled}
|
||||
onToggle={handleToggleIsEnabled}
|
||||
label="Dynamic Prompts"
|
||||
withSwitch
|
||||
>
|
||||
<IAICollapse label="Dynamic Prompts" activeLabel={activeLabel}>
|
||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||
<ParamDynamicPromptsToggle />
|
||||
<ParamDynamicPromptsCombinatorial />
|
||||
<ParamDynamicPromptsMaxPrompts />
|
||||
</Flex>
|
||||
|
@ -1,23 +1,23 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { combinatorialToggled } from '../store/slice';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useCallback } from 'react';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { useCallback } from 'react';
|
||||
import { combinatorialToggled } from '../store/slice';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
(state) => {
|
||||
const { combinatorial } = state.dynamicPrompts;
|
||||
const { combinatorial, isEnabled } = state.dynamicPrompts;
|
||||
|
||||
return { combinatorial };
|
||||
return { combinatorial, isDisabled: !isEnabled };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamDynamicPromptsCombinatorial = () => {
|
||||
const { combinatorial } = useAppSelector(selector);
|
||||
const { combinatorial, isDisabled } = useAppSelector(selector);
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleChange = useCallback(() => {
|
||||
@ -26,6 +26,7 @@ const ParamDynamicPromptsCombinatorial = () => {
|
||||
|
||||
return (
|
||||
<IAISwitch
|
||||
isDisabled={isDisabled}
|
||||
label="Combinatorial Generation"
|
||||
isChecked={combinatorial}
|
||||
onChange={handleChange}
|
||||
|
@ -0,0 +1,36 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { useCallback } from 'react';
|
||||
import { isEnabledToggled } from '../store/slice';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
(state) => {
|
||||
const { isEnabled } = state.dynamicPrompts;
|
||||
|
||||
return { isEnabled };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamDynamicPromptsToggle = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { isEnabled } = useAppSelector(selector);
|
||||
|
||||
const handleToggleIsEnabled = useCallback(() => {
|
||||
dispatch(isEnabledToggled());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<IAISwitch
|
||||
label="Enable Dynamic Prompts"
|
||||
isChecked={isEnabled}
|
||||
onChange={handleToggleIsEnabled}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default ParamDynamicPromptsToggle;
|
@ -1,25 +1,31 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { maxPromptsChanged, maxPromptsReset } from '../store/slice';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useCallback } from 'react';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { useCallback } from 'react';
|
||||
import { maxPromptsChanged, maxPromptsReset } from '../store/slice';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
(state) => {
|
||||
const { maxPrompts, combinatorial } = state.dynamicPrompts;
|
||||
const { maxPrompts, combinatorial, isEnabled } = state.dynamicPrompts;
|
||||
const { min, sliderMax, inputMax } =
|
||||
state.config.sd.dynamicPrompts.maxPrompts;
|
||||
|
||||
return { maxPrompts, min, sliderMax, inputMax, combinatorial };
|
||||
return {
|
||||
maxPrompts,
|
||||
min,
|
||||
sliderMax,
|
||||
inputMax,
|
||||
isDisabled: !isEnabled || !combinatorial,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamDynamicPromptsMaxPrompts = () => {
|
||||
const { maxPrompts, min, sliderMax, inputMax, combinatorial } =
|
||||
const { maxPrompts, min, sliderMax, inputMax, isDisabled } =
|
||||
useAppSelector(selector);
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
@ -37,7 +43,7 @@ const ParamDynamicPromptsMaxPrompts = () => {
|
||||
return (
|
||||
<IAISlider
|
||||
label="Max Prompts"
|
||||
isDisabled={!combinatorial}
|
||||
isDisabled={isDisabled}
|
||||
min={min}
|
||||
max={sliderMax}
|
||||
value={maxPrompts}
|
||||
|
@ -0,0 +1,33 @@
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { memo } from 'react';
|
||||
import { BiCode } from 'react-icons/bi';
|
||||
|
||||
type Props = {
|
||||
onClick: () => void;
|
||||
};
|
||||
|
||||
const AddEmbeddingButton = (props: Props) => {
|
||||
const { onClick } = props;
|
||||
return (
|
||||
<IAIIconButton
|
||||
size="sm"
|
||||
aria-label="Add Embedding"
|
||||
tooltip="Add Embedding"
|
||||
icon={<BiCode />}
|
||||
sx={{
|
||||
p: 2,
|
||||
color: 'base.700',
|
||||
_hover: {
|
||||
color: 'base.550',
|
||||
},
|
||||
_active: {
|
||||
color: 'base.500',
|
||||
},
|
||||
}}
|
||||
variant="link"
|
||||
onClick={onClick}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(AddEmbeddingButton);
|
@ -0,0 +1,151 @@
|
||||
import {
|
||||
Flex,
|
||||
Popover,
|
||||
PopoverBody,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
Text,
|
||||
} from '@chakra-ui/react';
|
||||
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
||||
import { forEach } from 'lodash-es';
|
||||
import {
|
||||
PropsWithChildren,
|
||||
forwardRef,
|
||||
useCallback,
|
||||
useMemo,
|
||||
useRef,
|
||||
} from 'react';
|
||||
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
|
||||
|
||||
type EmbeddingSelectItem = {
|
||||
label: string;
|
||||
value: string;
|
||||
description?: string;
|
||||
};
|
||||
|
||||
type Props = PropsWithChildren & {
|
||||
onSelect: (v: string) => void;
|
||||
isOpen: boolean;
|
||||
onClose: () => void;
|
||||
};
|
||||
|
||||
const ParamEmbeddingPopover = (props: Props) => {
|
||||
const { onSelect, isOpen, onClose, children } = props;
|
||||
const { data: embeddingQueryData } = useGetTextualInversionModelsQuery();
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!embeddingQueryData) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: EmbeddingSelectItem[] = [];
|
||||
|
||||
forEach(embeddingQueryData.entities, (embedding, _) => {
|
||||
if (!embedding) return;
|
||||
|
||||
data.push({
|
||||
value: embedding.name,
|
||||
label: embedding.name,
|
||||
description: embedding.description,
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [embeddingQueryData]);
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: string[]) => {
|
||||
if (v.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
onSelect(v[0]);
|
||||
},
|
||||
[onSelect]
|
||||
);
|
||||
|
||||
return (
|
||||
<Popover
|
||||
initialFocusRef={inputRef}
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
placement="bottom"
|
||||
openDelay={0}
|
||||
closeDelay={0}
|
||||
closeOnBlur={true}
|
||||
returnFocusOnClose={true}
|
||||
>
|
||||
<PopoverTrigger>{children}</PopoverTrigger>
|
||||
<PopoverContent
|
||||
sx={{
|
||||
p: 0,
|
||||
top: -1,
|
||||
shadow: 'dark-lg',
|
||||
borderColor: 'accent.300',
|
||||
borderWidth: '2px',
|
||||
borderStyle: 'solid',
|
||||
_dark: { borderColor: 'accent.400' },
|
||||
}}
|
||||
>
|
||||
<PopoverBody
|
||||
sx={{ p: 0, w: `calc(${PARAMETERS_PANEL_WIDTH} - 2rem )` }}
|
||||
>
|
||||
{data.length === 0 ? (
|
||||
<Flex sx={{ justifyContent: 'center', p: 2 }}>
|
||||
<Text
|
||||
sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}
|
||||
>
|
||||
No Embeddings Loaded
|
||||
</Text>
|
||||
</Flex>
|
||||
) : (
|
||||
<IAIMantineMultiSelect
|
||||
inputRef={inputRef}
|
||||
placeholder={'Add Embedding'}
|
||||
value={[]}
|
||||
data={data}
|
||||
maxDropdownHeight={400}
|
||||
nothingFound="No Matching Embeddings"
|
||||
itemComponent={SelectItem}
|
||||
disabled={data.length === 0}
|
||||
filter={(value, selected, item: EmbeddingSelectItem) =>
|
||||
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||
item.value.toLowerCase().includes(value.toLowerCase().trim())
|
||||
}
|
||||
onChange={handleChange}
|
||||
/>
|
||||
)}
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
};
|
||||
|
||||
export default ParamEmbeddingPopover;
|
||||
|
||||
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
|
||||
value: string;
|
||||
label: string;
|
||||
description?: string;
|
||||
}
|
||||
|
||||
const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
|
||||
({ label, description, ...others }: ItemProps, ref) => {
|
||||
return (
|
||||
<div ref={ref} {...others}>
|
||||
<div>
|
||||
<Text>{label}</Text>
|
||||
{description && (
|
||||
<Text size="xs" color="base.600">
|
||||
{description}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
SelectItem.displayName = 'SelectItem';
|
@ -1,16 +1,16 @@
|
||||
import { Flex, Text, useColorMode } from '@chakra-ui/react';
|
||||
import { Flex, useColorMode } from '@chakra-ui/react';
|
||||
import { FaImages } from 'react-icons/fa';
|
||||
import { boardIdSelected } from '../../store/boardSlice';
|
||||
import { boardIdSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { useDispatch } from 'react-redux';
|
||||
import { IAINoImageFallback } from 'common/components/IAIImageFallback';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { AnimatePresence } from 'framer-motion';
|
||||
import { SelectedItemOverlay } from '../SelectedItemOverlay';
|
||||
import { useCallback } from 'react';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { useRemoveImageFromBoardMutation } from 'services/api/endpoints/boardImages';
|
||||
import { useDroppable } from '@dnd-kit/core';
|
||||
import IAIDropOverlay from 'common/components/IAIDropOverlay';
|
||||
import { mode } from 'theme/util/mode';
|
||||
import {
|
||||
MoveBoardDropData,
|
||||
isValidDrop,
|
||||
useDroppable,
|
||||
} from 'app/components/ImageDnd/typesafeDnd';
|
||||
|
||||
const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => {
|
||||
const dispatch = useDispatch();
|
||||
@ -20,31 +20,15 @@ const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => {
|
||||
dispatch(boardIdSelected());
|
||||
};
|
||||
|
||||
const [removeImageFromBoard, { isLoading }] =
|
||||
useRemoveImageFromBoardMutation();
|
||||
const droppableData: MoveBoardDropData = {
|
||||
id: 'all-images-board',
|
||||
actionType: 'MOVE_BOARD',
|
||||
context: { boardId: null },
|
||||
};
|
||||
|
||||
const handleDrop = useCallback(
|
||||
(droppedImage: ImageDTO) => {
|
||||
if (!droppedImage.board_id) {
|
||||
return;
|
||||
}
|
||||
removeImageFromBoard({
|
||||
board_id: droppedImage.board_id,
|
||||
image_name: droppedImage.image_name,
|
||||
});
|
||||
},
|
||||
[removeImageFromBoard]
|
||||
);
|
||||
|
||||
const {
|
||||
isOver,
|
||||
setNodeRef,
|
||||
active: isDropActive,
|
||||
} = useDroppable({
|
||||
const { isOver, setNodeRef, active } = useDroppable({
|
||||
id: `board_droppable_all_images`,
|
||||
data: {
|
||||
handleDrop,
|
||||
},
|
||||
data: droppableData,
|
||||
});
|
||||
|
||||
return (
|
||||
@ -58,10 +42,10 @@ const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => {
|
||||
h: 'full',
|
||||
borderRadius: 'base',
|
||||
}}
|
||||
onClick={handleAllImagesBoardClick}
|
||||
>
|
||||
<Flex
|
||||
ref={setNodeRef}
|
||||
onClick={handleAllImagesBoardClick}
|
||||
sx={{
|
||||
position: 'relative',
|
||||
justifyContent: 'center',
|
||||
@ -69,18 +53,30 @@ const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => {
|
||||
borderRadius: 'base',
|
||||
w: 'full',
|
||||
aspectRatio: '1/1',
|
||||
overflow: 'hidden',
|
||||
shadow: isSelected ? 'selected.light' : undefined,
|
||||
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
|
||||
flexShrink: 0,
|
||||
}}
|
||||
>
|
||||
<IAINoImageFallback iconProps={{ boxSize: 8 }} as={FaImages} />
|
||||
<IAINoContentFallback
|
||||
boxSize={8}
|
||||
icon={FaImages}
|
||||
sx={{
|
||||
border: '2px solid var(--invokeai-colors-base-200)',
|
||||
_dark: { border: '2px solid var(--invokeai-colors-base-800)' },
|
||||
}}
|
||||
/>
|
||||
<AnimatePresence>
|
||||
{isSelected && <SelectedItemOverlay />}
|
||||
</AnimatePresence>
|
||||
<AnimatePresence>
|
||||
{isDropActive && <IAIDropOverlay isOver={isOver} />}
|
||||
{isValidDrop(droppableData, active) && (
|
||||
<IAIDropOverlay isOver={isOver} />
|
||||
)}
|
||||
</AnimatePresence>
|
||||
</Flex>
|
||||
<Text
|
||||
<Flex
|
||||
sx={{
|
||||
h: 'full',
|
||||
alignItems: 'center',
|
||||
color: isSelected
|
||||
? mode('base.900', 'base.50')(colorMode)
|
||||
: mode('base.700', 'base.200')(colorMode),
|
||||
@ -89,7 +85,7 @@ const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => {
|
||||
}}
|
||||
>
|
||||
All Images
|
||||
</Text>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
@ -2,6 +2,7 @@ import {
|
||||
Collapse,
|
||||
Flex,
|
||||
Grid,
|
||||
GridItem,
|
||||
IconButton,
|
||||
Input,
|
||||
InputGroup,
|
||||
@ -10,10 +11,7 @@ import {
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import {
|
||||
boardsSelector,
|
||||
setBoardSearchText,
|
||||
} from 'features/gallery/store/boardSlice';
|
||||
import { setBoardSearchText } from 'features/gallery/store/boardSlice';
|
||||
import { memo, useState } from 'react';
|
||||
import HoverableBoard from './HoverableBoard';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
@ -21,11 +19,13 @@ import AddBoardButton from './AddBoardButton';
|
||||
import AllImagesBoard from './AllImagesBoard';
|
||||
import { CloseIcon } from '@chakra-ui/icons';
|
||||
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
|
||||
const selector = createSelector(
|
||||
[boardsSelector],
|
||||
(boardsState) => {
|
||||
const { selectedBoardId, searchText } = boardsState;
|
||||
[stateSelector],
|
||||
({ boards, gallery }) => {
|
||||
const { searchText } = boards;
|
||||
const { selectedBoardId } = gallery;
|
||||
return { selectedBoardId, searchText };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
@ -109,20 +109,24 @@ const BoardsList = (props: Props) => {
|
||||
<Grid
|
||||
className="list-container"
|
||||
sx={{
|
||||
gap: 2,
|
||||
gridTemplateRows: '5.5rem 5.5rem',
|
||||
gridTemplateRows: '6.5rem 6.5rem',
|
||||
gridAutoFlow: 'column dense',
|
||||
gridAutoColumns: '4rem',
|
||||
gridAutoColumns: '5rem',
|
||||
}}
|
||||
>
|
||||
{!searchMode && <AllImagesBoard isSelected={!selectedBoardId} />}
|
||||
{!searchMode && (
|
||||
<GridItem sx={{ p: 1.5 }}>
|
||||
<AllImagesBoard isSelected={!selectedBoardId} />
|
||||
</GridItem>
|
||||
)}
|
||||
{filteredBoards &&
|
||||
filteredBoards.map((board) => (
|
||||
<HoverableBoard
|
||||
key={board.board_id}
|
||||
board={board}
|
||||
isSelected={selectedBoardId === board.board_id}
|
||||
/>
|
||||
<GridItem key={board.board_id} sx={{ p: 1.5 }}>
|
||||
<HoverableBoard
|
||||
board={board}
|
||||
isSelected={selectedBoardId === board.board_id}
|
||||
/>
|
||||
</GridItem>
|
||||
))}
|
||||
</Grid>
|
||||
</OverlayScrollbarsComponent>
|
||||
|
@ -15,10 +15,9 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { memo, useCallback, useContext } from 'react';
|
||||
import { FaFolder, FaTrash } from 'react-icons/fa';
|
||||
import { ContextMenu } from 'chakra-ui-contextmenu';
|
||||
import { BoardDTO, ImageDTO } from 'services/api/types';
|
||||
import { IAINoImageFallback } from 'common/components/IAIImageFallback';
|
||||
import { boardIdSelected } from 'features/gallery/store/boardSlice';
|
||||
import { useAddImageToBoardMutation } from 'services/api/endpoints/boardImages';
|
||||
import { BoardDTO } from 'services/api/types';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { boardIdSelected } from 'features/gallery/store/gallerySlice';
|
||||
import {
|
||||
useDeleteBoardMutation,
|
||||
useUpdateBoardMutation,
|
||||
@ -26,12 +25,15 @@ import {
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
import { useDroppable } from '@dnd-kit/core';
|
||||
import { AnimatePresence } from 'framer-motion';
|
||||
import IAIDropOverlay from 'common/components/IAIDropOverlay';
|
||||
import { SelectedItemOverlay } from '../SelectedItemOverlay';
|
||||
import { DeleteBoardImagesContext } from '../../../../app/contexts/DeleteBoardImagesContext';
|
||||
import { mode } from 'theme/util/mode';
|
||||
import {
|
||||
MoveBoardDropData,
|
||||
isValidDrop,
|
||||
useDroppable,
|
||||
} from 'app/components/ImageDnd/typesafeDnd';
|
||||
|
||||
interface HoverableBoardProps {
|
||||
board: BoardDTO;
|
||||
@ -61,9 +63,6 @@ const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => {
|
||||
const [deleteBoard, { isLoading: isDeleteBoardLoading }] =
|
||||
useDeleteBoardMutation();
|
||||
|
||||
const [addImageToBoard, { isLoading: isAddImageToBoardLoading }] =
|
||||
useAddImageToBoardMutation();
|
||||
|
||||
const handleUpdateBoardName = (newBoardName: string) => {
|
||||
updateBoard({ board_id, changes: { board_name: newBoardName } });
|
||||
};
|
||||
@ -77,29 +76,19 @@ const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => {
|
||||
onClickDeleteBoardImages(board);
|
||||
}, [board, onClickDeleteBoardImages]);
|
||||
|
||||
const handleDrop = useCallback(
|
||||
(droppedImage: ImageDTO) => {
|
||||
if (droppedImage.board_id === board_id) {
|
||||
return;
|
||||
}
|
||||
addImageToBoard({ board_id, image_name: droppedImage.image_name });
|
||||
},
|
||||
[addImageToBoard, board_id]
|
||||
);
|
||||
const droppableData: MoveBoardDropData = {
|
||||
id: board_id,
|
||||
actionType: 'MOVE_BOARD',
|
||||
context: { boardId: board_id },
|
||||
};
|
||||
|
||||
const {
|
||||
isOver,
|
||||
setNodeRef,
|
||||
active: isDropActive,
|
||||
} = useDroppable({
|
||||
const { isOver, setNodeRef, active } = useDroppable({
|
||||
id: `board_droppable_${board_id}`,
|
||||
data: {
|
||||
handleDrop,
|
||||
},
|
||||
data: droppableData,
|
||||
});
|
||||
|
||||
return (
|
||||
<Box sx={{ touchAction: 'none' }}>
|
||||
<Box sx={{ touchAction: 'none', height: 'full' }}>
|
||||
<ContextMenu<HTMLDivElement>
|
||||
menuProps={{ size: 'sm', isLazy: true }}
|
||||
renderMenu={() => (
|
||||
@ -148,13 +137,25 @@ const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => {
|
||||
w: 'full',
|
||||
aspectRatio: '1/1',
|
||||
overflow: 'hidden',
|
||||
shadow: isSelected ? 'selected.light' : undefined,
|
||||
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
|
||||
flexShrink: 0,
|
||||
}}
|
||||
>
|
||||
{board.cover_image_name && coverImage?.image_url && (
|
||||
<Image src={coverImage?.image_url} draggable={false} />
|
||||
)}
|
||||
{!(board.cover_image_name && coverImage?.image_url) && (
|
||||
<IAINoImageFallback iconProps={{ boxSize: 8 }} as={FaFolder} />
|
||||
<IAINoContentFallback
|
||||
boxSize={8}
|
||||
icon={FaFolder}
|
||||
sx={{
|
||||
border: '2px solid var(--invokeai-colors-base-200)',
|
||||
_dark: {
|
||||
border: '2px solid var(--invokeai-colors-base-800)',
|
||||
},
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
<Flex
|
||||
sx={{
|
||||
@ -167,14 +168,20 @@ const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => {
|
||||
<Badge variant="solid">{board.image_count}</Badge>
|
||||
</Flex>
|
||||
<AnimatePresence>
|
||||
{isSelected && <SelectedItemOverlay />}
|
||||
</AnimatePresence>
|
||||
<AnimatePresence>
|
||||
{isDropActive && <IAIDropOverlay isOver={isOver} />}
|
||||
{isValidDrop(droppableData, active) && (
|
||||
<IAIDropOverlay isOver={isOver} />
|
||||
)}
|
||||
</AnimatePresence>
|
||||
</Flex>
|
||||
|
||||
<Box sx={{ width: 'full' }}>
|
||||
<Flex
|
||||
sx={{
|
||||
width: 'full',
|
||||
height: 'full',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
}}
|
||||
>
|
||||
<Editable
|
||||
defaultValue={board_name}
|
||||
submitOnBlur={false}
|
||||
@ -204,7 +211,7 @@ const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => {
|
||||
}}
|
||||
/>
|
||||
</Editable>
|
||||
</Box>
|
||||
</Flex>
|
||||
</Flex>
|
||||
)}
|
||||
</ContextMenu>
|
||||
|
@ -38,8 +38,7 @@ import {
|
||||
FaShare,
|
||||
FaShareAlt,
|
||||
} from 'react-icons/fa';
|
||||
import { gallerySelector } from '../store/gallerySelectors';
|
||||
import { useCallback, useContext } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||
@ -49,22 +48,15 @@ import FaceRestoreSettings from 'features/parameters/components/Parameters/FaceR
|
||||
import UpscaleSettings from 'features/parameters/components/Parameters/Upscale/UpscaleSettings';
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||
import { DeleteImageContext } from 'app/contexts/DeleteImageContext';
|
||||
import { DeleteImageButton } from './DeleteImageModal';
|
||||
import { selectImagesById } from '../store/imagesSlice';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice';
|
||||
import { DeleteImageButton } from 'features/imageDeletion/components/DeleteImageButton';
|
||||
|
||||
const currentImageButtonsSelector = createSelector(
|
||||
[
|
||||
(state: RootState) => state,
|
||||
systemSelector,
|
||||
gallerySelector,
|
||||
postprocessingSelector,
|
||||
uiSelector,
|
||||
lightboxSelector,
|
||||
activeTabNameSelector,
|
||||
],
|
||||
(state, system, gallery, postprocessing, ui, lightbox, activeTabName) => {
|
||||
[stateSelector, activeTabNameSelector],
|
||||
({ gallery, system, postprocessing, ui, lightbox }, activeTabName) => {
|
||||
const {
|
||||
isProcessing,
|
||||
isConnected,
|
||||
@ -84,9 +76,7 @@ const currentImageButtonsSelector = createSelector(
|
||||
shouldShowProgressInViewer,
|
||||
} = ui;
|
||||
|
||||
const imageDTO = selectImagesById(state, gallery.selectedImage ?? '');
|
||||
|
||||
const { selectedImage } = gallery;
|
||||
const lastSelectedImage = gallery.selection[gallery.selection.length - 1];
|
||||
|
||||
return {
|
||||
canDeleteImage: isConnected && !isProcessing,
|
||||
@ -97,16 +87,13 @@ const currentImageButtonsSelector = createSelector(
|
||||
isESRGANAvailable,
|
||||
upscalingLevel,
|
||||
facetoolStrength,
|
||||
shouldDisableToolbarButtons: Boolean(progressImage) || !selectedImage,
|
||||
shouldDisableToolbarButtons: Boolean(progressImage) || !lastSelectedImage,
|
||||
shouldShowImageDetails,
|
||||
activeTabName,
|
||||
isLightboxOpen,
|
||||
shouldHidePreview,
|
||||
image: imageDTO,
|
||||
seed: imageDTO?.metadata?.seed,
|
||||
prompt: imageDTO?.metadata?.positive_conditioning,
|
||||
negativePrompt: imageDTO?.metadata?.negative_conditioning,
|
||||
shouldShowProgressInViewer,
|
||||
lastSelectedImage,
|
||||
};
|
||||
},
|
||||
{
|
||||
@ -132,7 +119,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
isLightboxOpen,
|
||||
activeTabName,
|
||||
shouldHidePreview,
|
||||
image,
|
||||
lastSelectedImage,
|
||||
shouldShowProgressInViewer,
|
||||
} = useAppSelector(currentImageButtonsSelector);
|
||||
|
||||
@ -147,7 +134,9 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
const { recallBothPrompts, recallSeed, recallAllParameters } =
|
||||
useRecallParameters();
|
||||
|
||||
const { onDelete } = useContext(DeleteImageContext);
|
||||
const { currentData: image } = useGetImageDTOQuery(
|
||||
lastSelectedImage ?? skipToken
|
||||
);
|
||||
|
||||
// const handleCopyImage = useCallback(async () => {
|
||||
// if (!image?.url) {
|
||||
@ -248,8 +237,11 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
}, []);
|
||||
|
||||
const handleDelete = useCallback(() => {
|
||||
onDelete(image);
|
||||
}, [image, onDelete]);
|
||||
if (!image) {
|
||||
return;
|
||||
}
|
||||
dispatch(imageToDeleteSelected(image));
|
||||
}, [dispatch, image]);
|
||||
|
||||
useHotkeys(
|
||||
'Shift+U',
|
||||
@ -371,7 +363,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
}}
|
||||
{...props}
|
||||
>
|
||||
<ButtonGroup isAttached={true}>
|
||||
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
|
||||
<IAIPopover
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
@ -444,11 +436,12 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
}
|
||||
isChecked={isLightboxOpen}
|
||||
onClick={handleLightBox}
|
||||
isDisabled={shouldDisableToolbarButtons}
|
||||
/>
|
||||
)}
|
||||
</ButtonGroup>
|
||||
|
||||
<ButtonGroup isAttached={true}>
|
||||
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
|
||||
<IAIIconButton
|
||||
icon={<FaQuoteRight />}
|
||||
tooltip={`${t('parameters.usePrompt')} (P)`}
|
||||
@ -478,7 +471,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
</ButtonGroup>
|
||||
|
||||
{(isUpscalingEnabled || isFaceRestoreEnabled) && (
|
||||
<ButtonGroup isAttached={true}>
|
||||
<ButtonGroup
|
||||
isAttached={true}
|
||||
isDisabled={shouldDisableToolbarButtons}
|
||||
>
|
||||
{isFaceRestoreEnabled && (
|
||||
<IAIPopover
|
||||
triggerComponent={
|
||||
@ -543,7 +539,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
</ButtonGroup>
|
||||
)}
|
||||
|
||||
<ButtonGroup isAttached={true}>
|
||||
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
|
||||
<IAIIconButton
|
||||
icon={<FaCode />}
|
||||
tooltip={`${t('parameters.info')} (I)`}
|
||||
@ -553,7 +549,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
/>
|
||||
</ButtonGroup>
|
||||
|
||||
<ButtonGroup isAttached={true}>
|
||||
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
|
||||
<IAIIconButton
|
||||
aria-label={t('settings.displayInProgress')}
|
||||
tooltip={t('settings.displayInProgress')}
|
||||
@ -564,7 +560,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
</ButtonGroup>
|
||||
|
||||
<ButtonGroup isAttached={true}>
|
||||
<DeleteImageButton onClick={handleDelete} />
|
||||
<DeleteImageButton
|
||||
onClick={handleDelete}
|
||||
isDisabled={shouldDisableToolbarButtons}
|
||||
/>
|
||||
</ButtonGroup>
|
||||
</Flex>
|
||||
</>
|
||||
|
@ -1,29 +1,9 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
|
||||
import { gallerySelector } from '../store/gallerySelectors';
|
||||
import CurrentImageButtons from './CurrentImageButtons';
|
||||
import CurrentImagePreview from './CurrentImagePreview';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
|
||||
export const currentImageDisplaySelector = createSelector(
|
||||
[systemSelector, gallerySelector],
|
||||
(system, gallery) => {
|
||||
const { progressImage } = system;
|
||||
|
||||
return {
|
||||
hasSelectedImage: Boolean(gallery.selectedImage),
|
||||
hasProgressImage: Boolean(progressImage),
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const CurrentImageDisplay = () => {
|
||||
const { hasSelectedImage } = useAppSelector(currentImageDisplaySelector);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
@ -36,7 +16,7 @@ const CurrentImageDisplay = () => {
|
||||
justifyContent: 'center',
|
||||
}}
|
||||
>
|
||||
{hasSelectedImage && <CurrentImageButtons />}
|
||||
<CurrentImageButtons />
|
||||
<CurrentImagePreview />
|
||||
</Flex>
|
||||
);
|
||||
|
@ -1,35 +1,33 @@
|
||||
import { Box, Flex, Image } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
import {
|
||||
TypesafeDraggableData,
|
||||
TypesafeDroppableData,
|
||||
} from 'app/components/ImageDnd/typesafeDnd';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import { selectLastSelectedImage } from 'features/gallery/store/gallerySlice';
|
||||
import { isEqual } from 'lodash-es';
|
||||
|
||||
import { gallerySelector } from '../store/gallerySelectors';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
|
||||
import NextPrevImageButtons from './NextPrevImageButtons';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { imageSelected } from '../store/gallerySlice';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
|
||||
export const imagesSelector = createSelector(
|
||||
[uiSelector, gallerySelector, systemSelector],
|
||||
(ui, gallery, system) => {
|
||||
[stateSelector, selectLastSelectedImage],
|
||||
({ ui, system }, lastSelectedImage) => {
|
||||
const {
|
||||
shouldShowImageDetails,
|
||||
shouldHidePreview,
|
||||
shouldShowProgressInViewer,
|
||||
} = ui;
|
||||
const { selectedImage } = gallery;
|
||||
const { progressImage, shouldAntialiasProgressImage } = system;
|
||||
return {
|
||||
shouldShowImageDetails,
|
||||
shouldHidePreview,
|
||||
selectedImage,
|
||||
imageName: lastSelectedImage,
|
||||
progressImage,
|
||||
shouldShowProgressInViewer,
|
||||
shouldAntialiasProgressImage,
|
||||
@ -45,29 +43,35 @@ export const imagesSelector = createSelector(
|
||||
const CurrentImagePreview = () => {
|
||||
const {
|
||||
shouldShowImageDetails,
|
||||
selectedImage,
|
||||
imageName,
|
||||
progressImage,
|
||||
shouldShowProgressInViewer,
|
||||
shouldAntialiasProgressImage,
|
||||
} = useAppSelector(imagesSelector);
|
||||
|
||||
const {
|
||||
currentData: image,
|
||||
currentData: imageDTO,
|
||||
isLoading,
|
||||
isError,
|
||||
isSuccess,
|
||||
} = useGetImageDTOQuery(selectedImage ?? skipToken);
|
||||
} = useGetImageDTOQuery(imageName ?? skipToken);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
|
||||
if (imageDTO) {
|
||||
return {
|
||||
id: 'current-image',
|
||||
payloadType: 'IMAGE_DTO',
|
||||
payload: { imageDTO },
|
||||
};
|
||||
}
|
||||
}, [imageDTO]);
|
||||
|
||||
const handleDrop = useCallback(
|
||||
(droppedImage: ImageDTO) => {
|
||||
if (droppedImage.image_name === image?.image_name) {
|
||||
return;
|
||||
}
|
||||
dispatch(imageSelected(droppedImage.image_name));
|
||||
},
|
||||
[dispatch, image?.image_name]
|
||||
const droppableData = useMemo<TypesafeDroppableData | undefined>(
|
||||
() => ({
|
||||
id: 'current-image',
|
||||
actionType: 'SET_CURRENT_IMAGE',
|
||||
}),
|
||||
[]
|
||||
);
|
||||
|
||||
return (
|
||||
@ -98,14 +102,15 @@ const CurrentImagePreview = () => {
|
||||
/>
|
||||
) : (
|
||||
<IAIDndImage
|
||||
image={image}
|
||||
onDrop={handleDrop}
|
||||
fallback={<IAIImageLoadingFallback sx={{ bg: 'none' }} />}
|
||||
imageDTO={imageDTO}
|
||||
droppableData={droppableData}
|
||||
draggableData={draggableData}
|
||||
isUploadDisabled={true}
|
||||
fitContainer
|
||||
dropLabel="Set as Current Image"
|
||||
/>
|
||||
)}
|
||||
{shouldShowImageDetails && image && (
|
||||
{shouldShowImageDetails && imageDTO && (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
@ -116,10 +121,10 @@ const CurrentImagePreview = () => {
|
||||
overflow: 'scroll',
|
||||
}}
|
||||
>
|
||||
<ImageMetadataViewer image={image} />
|
||||
<ImageMetadataViewer image={imageDTO} />
|
||||
</Box>
|
||||
)}
|
||||
{!shouldShowImageDetails && image && (
|
||||
{!shouldShowImageDetails && imageDTO && (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
|
@ -1,166 +0,0 @@
|
||||
import {
|
||||
AlertDialog,
|
||||
AlertDialogBody,
|
||||
AlertDialogContent,
|
||||
AlertDialogFooter,
|
||||
AlertDialogHeader,
|
||||
AlertDialogOverlay,
|
||||
Divider,
|
||||
Flex,
|
||||
ListItem,
|
||||
Text,
|
||||
UnorderedList,
|
||||
} from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import {
|
||||
DeleteImageContext,
|
||||
ImageUsage,
|
||||
} from 'app/contexts/DeleteImageContext';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { configSelector } from 'features/system/store/configSelectors';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { setShouldConfirmOnDelete } from 'features/system/store/systemSlice';
|
||||
import { some } from 'lodash-es';
|
||||
|
||||
import { ChangeEvent, memo, useCallback, useContext, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaTrash } from 'react-icons/fa';
|
||||
|
||||
const selector = createSelector(
|
||||
[systemSelector, configSelector],
|
||||
(system, config) => {
|
||||
const { shouldConfirmOnDelete } = system;
|
||||
const { canRestoreDeletedImagesFromBin } = config;
|
||||
|
||||
return {
|
||||
shouldConfirmOnDelete,
|
||||
canRestoreDeletedImagesFromBin,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ImageInUseMessage = (props: { imageUsage?: ImageUsage }) => {
|
||||
const { imageUsage } = props;
|
||||
|
||||
if (!imageUsage) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!some(imageUsage)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Text>This image is currently in use in the following features:</Text>
|
||||
<UnorderedList sx={{ paddingInlineStart: 6 }}>
|
||||
{imageUsage.isInitialImage && <ListItem>Image to Image</ListItem>}
|
||||
{imageUsage.isCanvasImage && <ListItem>Unified Canvas</ListItem>}
|
||||
{imageUsage.isControlNetImage && <ListItem>ControlNet</ListItem>}
|
||||
{imageUsage.isNodesImage && <ListItem>Node Editor</ListItem>}
|
||||
</UnorderedList>
|
||||
<Text>
|
||||
If you delete this image, those features will immediately be reset.
|
||||
</Text>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
const DeleteImageModal = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { isOpen, onClose, onImmediatelyDelete, image, imageUsage } =
|
||||
useContext(DeleteImageContext);
|
||||
|
||||
const { shouldConfirmOnDelete, canRestoreDeletedImagesFromBin } =
|
||||
useAppSelector(selector);
|
||||
|
||||
const handleChangeShouldConfirmOnDelete = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) =>
|
||||
dispatch(setShouldConfirmOnDelete(!e.target.checked)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const cancelRef = useRef<HTMLButtonElement>(null);
|
||||
|
||||
return (
|
||||
<AlertDialog
|
||||
isOpen={isOpen}
|
||||
leastDestructiveRef={cancelRef}
|
||||
onClose={onClose}
|
||||
isCentered
|
||||
>
|
||||
<AlertDialogOverlay>
|
||||
<AlertDialogContent>
|
||||
<AlertDialogHeader fontSize="lg" fontWeight="bold">
|
||||
{t('gallery.deleteImage')}
|
||||
</AlertDialogHeader>
|
||||
|
||||
<AlertDialogBody>
|
||||
<Flex direction="column" gap={3}>
|
||||
<ImageInUseMessage imageUsage={imageUsage} />
|
||||
<Divider />
|
||||
<Text>
|
||||
{canRestoreDeletedImagesFromBin
|
||||
? t('gallery.deleteImageBin')
|
||||
: t('gallery.deleteImagePermanent')}
|
||||
</Text>
|
||||
<Text>{t('common.areYouSure')}</Text>
|
||||
<IAISwitch
|
||||
label={t('common.dontAskMeAgain')}
|
||||
isChecked={!shouldConfirmOnDelete}
|
||||
onChange={handleChangeShouldConfirmOnDelete}
|
||||
/>
|
||||
</Flex>
|
||||
</AlertDialogBody>
|
||||
<AlertDialogFooter>
|
||||
<IAIButton ref={cancelRef} onClick={onClose}>
|
||||
Cancel
|
||||
</IAIButton>
|
||||
<IAIButton colorScheme="error" onClick={onImmediatelyDelete} ml={3}>
|
||||
Delete
|
||||
</IAIButton>
|
||||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
||||
</AlertDialogOverlay>
|
||||
</AlertDialog>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(DeleteImageModal);
|
||||
|
||||
const deleteImageButtonsSelector = createSelector(
|
||||
[systemSelector],
|
||||
(system) => {
|
||||
const { isProcessing, isConnected } = system;
|
||||
|
||||
return isConnected && !isProcessing;
|
||||
}
|
||||
);
|
||||
|
||||
type DeleteImageButtonProps = {
|
||||
onClick: () => void;
|
||||
};
|
||||
|
||||
export const DeleteImageButton = (props: DeleteImageButtonProps) => {
|
||||
const { onClick } = props;
|
||||
const { t } = useTranslation();
|
||||
const canDeleteImage = useAppSelector(deleteImageButtonsSelector);
|
||||
|
||||
return (
|
||||
<IAIIconButton
|
||||
onClick={onClick}
|
||||
icon={<FaTrash />}
|
||||
tooltip={`${t('gallery.deleteImage')} (Del)`}
|
||||
aria-label={`${t('gallery.deleteImage')} (Del)`}
|
||||
isDisabled={!canDeleteImage}
|
||||
colorScheme="error"
|
||||
/>
|
||||
);
|
||||
};
|
@ -0,0 +1,132 @@
|
||||
import { Box } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice';
|
||||
import { MouseEvent, memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaTrash } from 'react-icons/fa';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import {
|
||||
imageRangeEndSelected,
|
||||
imageSelected,
|
||||
imageSelectionToggled,
|
||||
} from '../store/gallerySlice';
|
||||
import ImageContextMenu from './ImageContextMenu';
|
||||
|
||||
export const makeSelector = (image_name: string) =>
|
||||
createSelector(
|
||||
[stateSelector],
|
||||
({ gallery }) => {
|
||||
const isSelected = gallery.selection.includes(image_name);
|
||||
const selectionCount = gallery.selection.length;
|
||||
|
||||
return {
|
||||
isSelected,
|
||||
selectionCount,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
interface HoverableImageProps {
|
||||
imageDTO: ImageDTO;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gallery image component with delete/use all/use seed buttons on hover.
|
||||
*/
|
||||
const GalleryImage = (props: HoverableImageProps) => {
|
||||
const { imageDTO } = props;
|
||||
const { image_url, thumbnail_url, image_name } = imageDTO;
|
||||
|
||||
const localSelector = useMemo(() => makeSelector(image_name), [image_name]);
|
||||
|
||||
const { isSelected, selectionCount } = useAppSelector(localSelector);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleClick = useCallback(
|
||||
(e: MouseEvent<HTMLDivElement>) => {
|
||||
if (e.shiftKey) {
|
||||
dispatch(imageRangeEndSelected(props.imageDTO.image_name));
|
||||
} else if (e.ctrlKey || e.metaKey) {
|
||||
dispatch(imageSelectionToggled(props.imageDTO.image_name));
|
||||
} else {
|
||||
dispatch(imageSelected(props.imageDTO.image_name));
|
||||
}
|
||||
},
|
||||
[dispatch, props.imageDTO.image_name]
|
||||
);
|
||||
|
||||
const handleDelete = useCallback(
|
||||
(e: MouseEvent<HTMLButtonElement>) => {
|
||||
e.stopPropagation();
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
dispatch(imageToDeleteSelected(imageDTO));
|
||||
},
|
||||
[dispatch, imageDTO]
|
||||
);
|
||||
|
||||
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
|
||||
if (selectionCount > 1) {
|
||||
return {
|
||||
id: 'gallery-image',
|
||||
payloadType: 'GALLERY_SELECTION',
|
||||
};
|
||||
}
|
||||
|
||||
if (imageDTO) {
|
||||
return {
|
||||
id: 'gallery-image',
|
||||
payloadType: 'IMAGE_DTO',
|
||||
payload: { imageDTO },
|
||||
};
|
||||
}
|
||||
}, [imageDTO, selectionCount]);
|
||||
|
||||
return (
|
||||
<Box sx={{ w: 'full', h: 'full', touchAction: 'none' }}>
|
||||
<ImageContextMenu image={imageDTO}>
|
||||
{(ref) => (
|
||||
<Box
|
||||
position="relative"
|
||||
key={image_name}
|
||||
userSelect="none"
|
||||
ref={ref}
|
||||
sx={{
|
||||
display: 'flex',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
aspectRatio: '1/1',
|
||||
}}
|
||||
>
|
||||
<IAIDndImage
|
||||
onClick={handleClick}
|
||||
imageDTO={imageDTO}
|
||||
draggableData={draggableData}
|
||||
isSelected={isSelected}
|
||||
minSize={0}
|
||||
onClickReset={handleDelete}
|
||||
resetIcon={<FaTrash />}
|
||||
resetTooltip="Delete image"
|
||||
imageSx={{ w: 'full', h: 'full' }}
|
||||
// withResetIcon // removed bc it's too easy to accidentally delete images
|
||||
isDropDisabled={true}
|
||||
isUploadDisabled={true}
|
||||
/>
|
||||
</Box>
|
||||
)}
|
||||
</ImageContextMenu>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(GalleryImage);
|
@ -1,371 +0,0 @@
|
||||
import { Box, Flex, Icon, Image, MenuItem, MenuList } from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { memo, useCallback, useContext, useState } from 'react';
|
||||
import {
|
||||
FaCheck,
|
||||
FaExpand,
|
||||
FaFolder,
|
||||
FaImage,
|
||||
FaShare,
|
||||
FaTrash,
|
||||
} from 'react-icons/fa';
|
||||
import { ContextMenu } from 'chakra-ui-contextmenu';
|
||||
import {
|
||||
resizeAndScaleCanvas,
|
||||
setInitialCanvasImage,
|
||||
} from 'features/canvas/store/canvasSlice';
|
||||
import { gallerySelector } from 'features/gallery/store/gallerySelectors';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { ExternalLinkIcon } from '@chakra-ui/icons';
|
||||
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||
import { sentImageToCanvas, sentImageToImg2Img } from '../store/actions';
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { useDraggable } from '@dnd-kit/core';
|
||||
import { DeleteImageContext } from 'app/contexts/DeleteImageContext';
|
||||
import { AddImageToBoardContext } from '../../../app/contexts/AddImageToBoardContext';
|
||||
import { useRemoveImageFromBoardMutation } from 'services/api/endpoints/boardImages';
|
||||
|
||||
export const selector = createSelector(
|
||||
[gallerySelector, systemSelector, lightboxSelector, activeTabNameSelector],
|
||||
(gallery, system, lightbox, activeTabName) => {
|
||||
const {
|
||||
galleryImageObjectFit,
|
||||
galleryImageMinimumWidth,
|
||||
shouldUseSingleGalleryColumn,
|
||||
} = gallery;
|
||||
|
||||
const { isLightboxOpen } = lightbox;
|
||||
const { isConnected, isProcessing, shouldConfirmOnDelete } = system;
|
||||
|
||||
return {
|
||||
canDeleteImage: isConnected && !isProcessing,
|
||||
shouldConfirmOnDelete,
|
||||
galleryImageObjectFit,
|
||||
galleryImageMinimumWidth,
|
||||
shouldUseSingleGalleryColumn,
|
||||
activeTabName,
|
||||
isLightboxOpen,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
interface HoverableImageProps {
|
||||
image: ImageDTO;
|
||||
isSelected: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gallery image component with delete/use all/use seed buttons on hover.
|
||||
*/
|
||||
const HoverableImage = (props: HoverableImageProps) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const {
|
||||
activeTabName,
|
||||
galleryImageObjectFit,
|
||||
galleryImageMinimumWidth,
|
||||
canDeleteImage,
|
||||
shouldUseSingleGalleryColumn,
|
||||
} = useAppSelector(selector);
|
||||
|
||||
const { image, isSelected } = props;
|
||||
const { image_url, thumbnail_url, image_name } = image;
|
||||
|
||||
const [isHovered, setIsHovered] = useState<boolean>(false);
|
||||
const toaster = useAppToaster();
|
||||
|
||||
const { t } = useTranslation();
|
||||
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
|
||||
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
||||
|
||||
const { onDelete } = useContext(DeleteImageContext);
|
||||
const { onClickAddToBoard } = useContext(AddImageToBoardContext);
|
||||
const handleDelete = useCallback(() => {
|
||||
onDelete(image);
|
||||
}, [image, onDelete]);
|
||||
const { recallBothPrompts, recallSeed, recallAllParameters } =
|
||||
useRecallParameters();
|
||||
|
||||
const { attributes, listeners, setNodeRef } = useDraggable({
|
||||
id: `galleryImage_${image_name}`,
|
||||
data: {
|
||||
image,
|
||||
},
|
||||
});
|
||||
|
||||
const [removeFromBoard] = useRemoveImageFromBoardMutation();
|
||||
|
||||
const handleMouseOver = () => setIsHovered(true);
|
||||
const handleMouseOut = () => setIsHovered(false);
|
||||
|
||||
const handleSelectImage = useCallback(() => {
|
||||
dispatch(imageSelected(image.image_name));
|
||||
}, [image, dispatch]);
|
||||
|
||||
// Recall parameters handlers
|
||||
const handleRecallPrompt = useCallback(() => {
|
||||
recallBothPrompts(
|
||||
image.metadata?.positive_conditioning,
|
||||
image.metadata?.negative_conditioning
|
||||
);
|
||||
}, [
|
||||
image.metadata?.negative_conditioning,
|
||||
image.metadata?.positive_conditioning,
|
||||
recallBothPrompts,
|
||||
]);
|
||||
|
||||
const handleRecallSeed = useCallback(() => {
|
||||
recallSeed(image.metadata?.seed);
|
||||
}, [image, recallSeed]);
|
||||
|
||||
const handleSendToImageToImage = useCallback(() => {
|
||||
dispatch(sentImageToImg2Img());
|
||||
dispatch(initialImageSelected(image));
|
||||
}, [dispatch, image]);
|
||||
|
||||
// const handleRecallInitialImage = useCallback(() => {
|
||||
// recallInitialImage(image.metadata.invokeai?.node?.image);
|
||||
// }, [image, recallInitialImage]);
|
||||
|
||||
/**
|
||||
* TODO: the rest of these
|
||||
*/
|
||||
const handleSendToCanvas = () => {
|
||||
dispatch(sentImageToCanvas());
|
||||
dispatch(setInitialCanvasImage(image));
|
||||
|
||||
dispatch(resizeAndScaleCanvas());
|
||||
|
||||
if (activeTabName !== 'unifiedCanvas') {
|
||||
dispatch(setActiveTab('unifiedCanvas'));
|
||||
}
|
||||
|
||||
toaster({
|
||||
title: t('toast.sentToUnifiedCanvas'),
|
||||
status: 'success',
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
};
|
||||
|
||||
const handleUseAllParameters = useCallback(() => {
|
||||
recallAllParameters(image);
|
||||
}, [image, recallAllParameters]);
|
||||
|
||||
const handleLightBox = () => {
|
||||
// dispatch(setCurrentImage(image));
|
||||
// dispatch(setIsLightboxOpen(true));
|
||||
};
|
||||
|
||||
const handleAddToBoard = useCallback(() => {
|
||||
onClickAddToBoard(image);
|
||||
}, [image, onClickAddToBoard]);
|
||||
|
||||
const handleRemoveFromBoard = useCallback(() => {
|
||||
if (!image.board_id) {
|
||||
return;
|
||||
}
|
||||
removeFromBoard({ board_id: image.board_id, image_name: image.image_name });
|
||||
}, [image.board_id, image.image_name, removeFromBoard]);
|
||||
|
||||
const handleOpenInNewTab = () => {
|
||||
window.open(image.image_url, '_blank');
|
||||
};
|
||||
|
||||
return (
|
||||
<Box
|
||||
ref={setNodeRef}
|
||||
{...listeners}
|
||||
{...attributes}
|
||||
sx={{ w: 'full', h: 'full', touchAction: 'none' }}
|
||||
>
|
||||
<ContextMenu<HTMLDivElement>
|
||||
menuProps={{ size: 'sm', isLazy: true }}
|
||||
renderMenu={() => (
|
||||
<MenuList sx={{ visibility: 'visible !important' }}>
|
||||
<MenuItem
|
||||
icon={<ExternalLinkIcon />}
|
||||
onClickCapture={handleOpenInNewTab}
|
||||
>
|
||||
{t('common.openInNewTab')}
|
||||
</MenuItem>
|
||||
{isLightboxEnabled && (
|
||||
<MenuItem icon={<FaExpand />} onClickCapture={handleLightBox}>
|
||||
{t('parameters.openInViewer')}
|
||||
</MenuItem>
|
||||
)}
|
||||
<MenuItem
|
||||
icon={<IoArrowUndoCircleOutline />}
|
||||
onClickCapture={handleRecallPrompt}
|
||||
isDisabled={image?.metadata?.positive_conditioning === undefined}
|
||||
>
|
||||
{t('parameters.usePrompt')}
|
||||
</MenuItem>
|
||||
|
||||
<MenuItem
|
||||
icon={<IoArrowUndoCircleOutline />}
|
||||
onClickCapture={handleRecallSeed}
|
||||
isDisabled={image?.metadata?.seed === undefined}
|
||||
>
|
||||
{t('parameters.useSeed')}
|
||||
</MenuItem>
|
||||
{/* <MenuItem
|
||||
icon={<IoArrowUndoCircleOutline />}
|
||||
onClickCapture={handleRecallInitialImage}
|
||||
isDisabled={image?.metadata?.type !== 'img2img'}
|
||||
>
|
||||
{t('parameters.useInitImg')}
|
||||
</MenuItem> */}
|
||||
<MenuItem
|
||||
icon={<IoArrowUndoCircleOutline />}
|
||||
onClickCapture={handleUseAllParameters}
|
||||
isDisabled={
|
||||
// what should these be
|
||||
!['t2l', 'l2l', 'inpaint'].includes(
|
||||
String(image?.metadata?.type)
|
||||
)
|
||||
}
|
||||
>
|
||||
{t('parameters.useAll')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
icon={<FaShare />}
|
||||
onClickCapture={handleSendToImageToImage}
|
||||
id="send-to-img2img"
|
||||
>
|
||||
{t('parameters.sendToImg2Img')}
|
||||
</MenuItem>
|
||||
{isCanvasEnabled && (
|
||||
<MenuItem
|
||||
icon={<FaShare />}
|
||||
onClickCapture={handleSendToCanvas}
|
||||
id="send-to-canvas"
|
||||
>
|
||||
{t('parameters.sendToUnifiedCanvas')}
|
||||
</MenuItem>
|
||||
)}
|
||||
<MenuItem icon={<FaFolder />} onClickCapture={handleAddToBoard}>
|
||||
{image.board_id ? 'Change Board' : 'Add to Board'}
|
||||
</MenuItem>
|
||||
{image.board_id && (
|
||||
<MenuItem
|
||||
icon={<FaFolder />}
|
||||
onClickCapture={handleRemoveFromBoard}
|
||||
>
|
||||
Remove from Board
|
||||
</MenuItem>
|
||||
)}
|
||||
<MenuItem
|
||||
sx={{ color: 'error.300' }}
|
||||
icon={<FaTrash />}
|
||||
onClickCapture={handleDelete}
|
||||
>
|
||||
{t('gallery.deleteImage')}
|
||||
</MenuItem>
|
||||
</MenuList>
|
||||
)}
|
||||
>
|
||||
{(ref) => (
|
||||
<Box
|
||||
position="relative"
|
||||
key={image_name}
|
||||
onMouseOver={handleMouseOver}
|
||||
onMouseOut={handleMouseOut}
|
||||
userSelect="none"
|
||||
onClick={handleSelectImage}
|
||||
ref={ref}
|
||||
sx={{
|
||||
display: 'flex',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
transition: 'transform 0.2s ease-out',
|
||||
aspectRatio: '1/1',
|
||||
cursor: 'pointer',
|
||||
}}
|
||||
>
|
||||
<Image
|
||||
loading="lazy"
|
||||
objectFit={
|
||||
shouldUseSingleGalleryColumn ? 'contain' : galleryImageObjectFit
|
||||
}
|
||||
draggable={false}
|
||||
rounded="md"
|
||||
src={thumbnail_url || image_url}
|
||||
fallback={<FaImage />}
|
||||
sx={{
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
maxWidth: '100%',
|
||||
maxHeight: '100%',
|
||||
}}
|
||||
/>
|
||||
{isSelected && (
|
||||
<Flex
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: '0',
|
||||
insetInlineStart: '0',
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
pointerEvents: 'none',
|
||||
}}
|
||||
>
|
||||
<Icon
|
||||
filter={'drop-shadow(0px 0px 1rem black)'}
|
||||
as={FaCheck}
|
||||
sx={{
|
||||
width: '50%',
|
||||
height: '50%',
|
||||
maxWidth: '4rem',
|
||||
maxHeight: '4rem',
|
||||
fill: 'ok.500',
|
||||
}}
|
||||
/>
|
||||
</Flex>
|
||||
)}
|
||||
{isHovered && galleryImageMinimumWidth >= 100 && (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: 1,
|
||||
insetInlineEnd: 1,
|
||||
}}
|
||||
>
|
||||
<IAIIconButton
|
||||
onClickCapture={handleDelete}
|
||||
aria-label={t('gallery.deleteImage')}
|
||||
icon={<FaTrash />}
|
||||
size="xs"
|
||||
fontSize={14}
|
||||
isDisabled={!canDeleteImage}
|
||||
/>
|
||||
</Box>
|
||||
)}
|
||||
</Box>
|
||||
)}
|
||||
</ContextMenu>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(HoverableImage);
|
@ -0,0 +1,278 @@
|
||||
import { MenuItem, MenuList } from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { memo, useCallback, useContext } from 'react';
|
||||
import {
|
||||
FaExpand,
|
||||
FaFolder,
|
||||
FaFolderPlus,
|
||||
FaShare,
|
||||
FaTrash,
|
||||
} from 'react-icons/fa';
|
||||
import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu';
|
||||
import {
|
||||
resizeAndScaleCanvas,
|
||||
setInitialCanvasImage,
|
||||
} from 'features/canvas/store/canvasSlice';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ExternalLinkIcon } from '@chakra-ui/icons';
|
||||
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||
import { sentImageToCanvas, sentImageToImg2Img } from '../store/actions';
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { AddImageToBoardContext } from '../../../app/contexts/AddImageToBoardContext';
|
||||
import { useRemoveImageFromBoardMutation } from 'services/api/endpoints/boardImages';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { RootState, stateSelector } from 'app/store/store';
|
||||
import {
|
||||
imagesAddedToBatch,
|
||||
selectionAddedToBatch,
|
||||
} from 'features/batch/store/batchSlice';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice';
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector, (state: RootState, imageDTO: ImageDTO) => imageDTO],
|
||||
({ gallery, batch }, imageDTO) => {
|
||||
const selectionCount = gallery.selection.length;
|
||||
const isInBatch = batch.imageNames.includes(imageDTO.image_name);
|
||||
|
||||
return { selectionCount, isInBatch };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
type Props = {
|
||||
image: ImageDTO;
|
||||
children: ContextMenuProps<HTMLDivElement>['children'];
|
||||
};
|
||||
|
||||
const ImageContextMenu = ({ image, children }: Props) => {
|
||||
const { selectionCount, isInBatch } = useAppSelector((state) =>
|
||||
selector(state, image)
|
||||
);
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const toaster = useAppToaster();
|
||||
|
||||
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
|
||||
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
||||
|
||||
const { onClickAddToBoard } = useContext(AddImageToBoardContext);
|
||||
|
||||
const handleDelete = useCallback(() => {
|
||||
if (!image) {
|
||||
return;
|
||||
}
|
||||
dispatch(imageToDeleteSelected(image));
|
||||
}, [dispatch, image]);
|
||||
|
||||
const { recallBothPrompts, recallSeed, recallAllParameters } =
|
||||
useRecallParameters();
|
||||
|
||||
const [removeFromBoard] = useRemoveImageFromBoardMutation();
|
||||
|
||||
// Recall parameters handlers
|
||||
const handleRecallPrompt = useCallback(() => {
|
||||
recallBothPrompts(
|
||||
image.metadata?.positive_conditioning,
|
||||
image.metadata?.negative_conditioning
|
||||
);
|
||||
}, [
|
||||
image.metadata?.negative_conditioning,
|
||||
image.metadata?.positive_conditioning,
|
||||
recallBothPrompts,
|
||||
]);
|
||||
|
||||
const handleRecallSeed = useCallback(() => {
|
||||
recallSeed(image.metadata?.seed);
|
||||
}, [image, recallSeed]);
|
||||
|
||||
const handleSendToImageToImage = useCallback(() => {
|
||||
dispatch(sentImageToImg2Img());
|
||||
dispatch(initialImageSelected(image));
|
||||
}, [dispatch, image]);
|
||||
|
||||
// const handleRecallInitialImage = useCallback(() => {
|
||||
// recallInitialImage(image.metadata.invokeai?.node?.image);
|
||||
// }, [image, recallInitialImage]);
|
||||
|
||||
const handleSendToCanvas = () => {
|
||||
dispatch(sentImageToCanvas());
|
||||
dispatch(setInitialCanvasImage(image));
|
||||
dispatch(resizeAndScaleCanvas());
|
||||
dispatch(setActiveTab('unifiedCanvas'));
|
||||
|
||||
toaster({
|
||||
title: t('toast.sentToUnifiedCanvas'),
|
||||
status: 'success',
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
};
|
||||
|
||||
const handleUseAllParameters = useCallback(() => {
|
||||
recallAllParameters(image);
|
||||
}, [image, recallAllParameters]);
|
||||
|
||||
const handleLightBox = () => {
|
||||
// dispatch(setCurrentImage(image));
|
||||
// dispatch(setIsLightboxOpen(true));
|
||||
};
|
||||
|
||||
const handleAddToBoard = useCallback(() => {
|
||||
onClickAddToBoard(image);
|
||||
}, [image, onClickAddToBoard]);
|
||||
|
||||
const handleRemoveFromBoard = useCallback(() => {
|
||||
if (!image.board_id) {
|
||||
return;
|
||||
}
|
||||
removeFromBoard({ board_id: image.board_id, image_name: image.image_name });
|
||||
}, [image.board_id, image.image_name, removeFromBoard]);
|
||||
|
||||
const handleOpenInNewTab = () => {
|
||||
window.open(image.image_url, '_blank');
|
||||
};
|
||||
|
||||
const handleAddSelectionToBatch = useCallback(() => {
|
||||
dispatch(selectionAddedToBatch());
|
||||
}, [dispatch]);
|
||||
|
||||
const handleAddToBatch = useCallback(() => {
|
||||
dispatch(imagesAddedToBatch([image.image_name]));
|
||||
}, [dispatch, image.image_name]);
|
||||
|
||||
return (
|
||||
<ContextMenu<HTMLDivElement>
|
||||
menuProps={{ size: 'sm', isLazy: true }}
|
||||
renderMenu={() => (
|
||||
<MenuList sx={{ visibility: 'visible !important' }}>
|
||||
{selectionCount === 1 ? (
|
||||
<>
|
||||
<MenuItem
|
||||
icon={<ExternalLinkIcon />}
|
||||
onClickCapture={handleOpenInNewTab}
|
||||
>
|
||||
{t('common.openInNewTab')}
|
||||
</MenuItem>
|
||||
{isLightboxEnabled && (
|
||||
<MenuItem icon={<FaExpand />} onClickCapture={handleLightBox}>
|
||||
{t('parameters.openInViewer')}
|
||||
</MenuItem>
|
||||
)}
|
||||
<MenuItem
|
||||
icon={<IoArrowUndoCircleOutline />}
|
||||
onClickCapture={handleRecallPrompt}
|
||||
isDisabled={
|
||||
image?.metadata?.positive_conditioning === undefined
|
||||
}
|
||||
>
|
||||
{t('parameters.usePrompt')}
|
||||
</MenuItem>
|
||||
|
||||
<MenuItem
|
||||
icon={<IoArrowUndoCircleOutline />}
|
||||
onClickCapture={handleRecallSeed}
|
||||
isDisabled={image?.metadata?.seed === undefined}
|
||||
>
|
||||
{t('parameters.useSeed')}
|
||||
</MenuItem>
|
||||
{/* <MenuItem
|
||||
icon={<IoArrowUndoCircleOutline />}
|
||||
onClickCapture={handleRecallInitialImage}
|
||||
isDisabled={image?.metadata?.type !== 'img2img'}
|
||||
>
|
||||
{t('parameters.useInitImg')}
|
||||
</MenuItem> */}
|
||||
<MenuItem
|
||||
icon={<IoArrowUndoCircleOutline />}
|
||||
onClickCapture={handleUseAllParameters}
|
||||
isDisabled={
|
||||
// what should these be
|
||||
!['t2l', 'l2l', 'inpaint'].includes(
|
||||
String(image?.metadata?.type)
|
||||
)
|
||||
}
|
||||
>
|
||||
{t('parameters.useAll')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
icon={<FaShare />}
|
||||
onClickCapture={handleSendToImageToImage}
|
||||
id="send-to-img2img"
|
||||
>
|
||||
{t('parameters.sendToImg2Img')}
|
||||
</MenuItem>
|
||||
{isCanvasEnabled && (
|
||||
<MenuItem
|
||||
icon={<FaShare />}
|
||||
onClickCapture={handleSendToCanvas}
|
||||
id="send-to-canvas"
|
||||
>
|
||||
{t('parameters.sendToUnifiedCanvas')}
|
||||
</MenuItem>
|
||||
)}
|
||||
{/* <MenuItem
|
||||
icon={<FaFolder />}
|
||||
isDisabled={isInBatch}
|
||||
onClickCapture={handleAddToBatch}
|
||||
>
|
||||
Add to Batch
|
||||
</MenuItem> */}
|
||||
<MenuItem icon={<FaFolder />} onClickCapture={handleAddToBoard}>
|
||||
{image.board_id ? 'Change Board' : 'Add to Board'}
|
||||
</MenuItem>
|
||||
{image.board_id && (
|
||||
<MenuItem
|
||||
icon={<FaFolder />}
|
||||
onClickCapture={handleRemoveFromBoard}
|
||||
>
|
||||
Remove from Board
|
||||
</MenuItem>
|
||||
)}
|
||||
<MenuItem
|
||||
sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
|
||||
icon={<FaTrash />}
|
||||
onClickCapture={handleDelete}
|
||||
>
|
||||
{t('gallery.deleteImage')}
|
||||
</MenuItem>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<MenuItem
|
||||
isDisabled={true}
|
||||
icon={<FaFolder />}
|
||||
onClickCapture={handleAddToBoard}
|
||||
>
|
||||
Move Selection to Board
|
||||
</MenuItem>
|
||||
{/* <MenuItem
|
||||
icon={<FaFolderPlus />}
|
||||
onClickCapture={handleAddSelectionToBatch}
|
||||
>
|
||||
Add Selection to Batch
|
||||
</MenuItem> */}
|
||||
<MenuItem
|
||||
sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
|
||||
icon={<FaTrash />}
|
||||
onClickCapture={handleDelete}
|
||||
>
|
||||
Delete Selection
|
||||
</MenuItem>
|
||||
</>
|
||||
)}
|
||||
</MenuList>
|
||||
)}
|
||||
>
|
||||
{children}
|
||||
</ContextMenu>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ImageContextMenu);
|
@ -5,7 +5,7 @@ import {
|
||||
Flex,
|
||||
FlexProps,
|
||||
Grid,
|
||||
Icon,
|
||||
Skeleton,
|
||||
Text,
|
||||
VStack,
|
||||
forwardRef,
|
||||
@ -18,12 +18,8 @@ import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import IAIPopover from 'common/components/IAIPopover';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { gallerySelector } from 'features/gallery/store/gallerySelectors';
|
||||
import {
|
||||
setGalleryImageMinimumWidth,
|
||||
setGalleryImageObjectFit,
|
||||
setShouldAutoSwitchToNewImages,
|
||||
setShouldUseSingleGalleryColumn,
|
||||
setGalleryView,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { togglePinGalleryPanel } from 'features/ui/store/uiSlice';
|
||||
@ -42,77 +38,56 @@ import {
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { BsPinAngle, BsPinAngleFill } from 'react-icons/bs';
|
||||
import { FaImage, FaServer, FaWrench } from 'react-icons/fa';
|
||||
import { MdPhotoLibrary } from 'react-icons/md';
|
||||
import HoverableImage from './HoverableImage';
|
||||
import GalleryImage from './GalleryImage';
|
||||
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { Virtuoso, VirtuosoGrid } from 'react-virtuoso';
|
||||
import { RootState, stateSelector } from 'app/store/store';
|
||||
import { VirtuosoGrid } from 'react-virtuoso';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import {
|
||||
ASSETS_CATEGORIES,
|
||||
IMAGE_CATEGORIES,
|
||||
imageCategoriesChanged,
|
||||
selectImagesAll,
|
||||
} from '../store/imagesSlice';
|
||||
shouldAutoSwitchChanged,
|
||||
selectFilteredImages,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { receivedPageOfImages } from 'services/api/thunks/image';
|
||||
import BoardsList from './Boards/BoardsList';
|
||||
import { boardsSelector } from '../store/boardSlice';
|
||||
import { ChevronUpIcon } from '@chakra-ui/icons';
|
||||
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
|
||||
import { mode } from 'theme/util/mode';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
|
||||
const itemSelector = createSelector(
|
||||
[(state: RootState) => state],
|
||||
(state) => {
|
||||
const { categories, total: allImagesTotal, isLoading } = state.images;
|
||||
const { selectedBoardId } = state.boards;
|
||||
const LOADING_IMAGE_ARRAY = Array(20).fill('loading');
|
||||
|
||||
const allImages = selectImagesAll(state);
|
||||
const selector = createSelector(
|
||||
[stateSelector, selectFilteredImages],
|
||||
(state, filteredImages) => {
|
||||
const {
|
||||
categories,
|
||||
total: allImagesTotal,
|
||||
isLoading,
|
||||
selectedBoardId,
|
||||
galleryImageMinimumWidth,
|
||||
galleryView,
|
||||
shouldAutoSwitch,
|
||||
} = state.gallery;
|
||||
const { shouldPinGallery } = state.ui;
|
||||
|
||||
const images = allImages.filter((i) => {
|
||||
const isInCategory = categories.includes(i.image_category);
|
||||
const isInSelectedBoard = selectedBoardId
|
||||
? i.board_id === selectedBoardId
|
||||
: true;
|
||||
return isInCategory && isInSelectedBoard;
|
||||
});
|
||||
const images = filteredImages as (ImageDTO | string)[];
|
||||
|
||||
return {
|
||||
images,
|
||||
images: isLoading ? images.concat(LOADING_IMAGE_ARRAY) : images,
|
||||
allImagesTotal,
|
||||
isLoading,
|
||||
categories,
|
||||
selectedBoardId,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const mainSelector = createSelector(
|
||||
[gallerySelector, uiSelector, boardsSelector],
|
||||
(gallery, ui, boards) => {
|
||||
const {
|
||||
galleryImageMinimumWidth,
|
||||
galleryImageObjectFit,
|
||||
shouldAutoSwitchToNewImages,
|
||||
shouldUseSingleGalleryColumn,
|
||||
selectedImage,
|
||||
galleryView,
|
||||
} = gallery;
|
||||
|
||||
const { shouldPinGallery } = ui;
|
||||
return {
|
||||
shouldPinGallery,
|
||||
galleryImageMinimumWidth,
|
||||
galleryImageObjectFit,
|
||||
shouldAutoSwitchToNewImages,
|
||||
shouldUseSingleGalleryColumn,
|
||||
selectedImage,
|
||||
shouldAutoSwitch,
|
||||
galleryView,
|
||||
selectedBoardId: boards.selectedBoardId,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
@ -140,17 +115,16 @@ const ImageGalleryContent = () => {
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const {
|
||||
images,
|
||||
isLoading,
|
||||
allImagesTotal,
|
||||
categories,
|
||||
selectedBoardId,
|
||||
shouldPinGallery,
|
||||
galleryImageMinimumWidth,
|
||||
galleryImageObjectFit,
|
||||
shouldAutoSwitchToNewImages,
|
||||
shouldUseSingleGalleryColumn,
|
||||
selectedImage,
|
||||
shouldAutoSwitch,
|
||||
galleryView,
|
||||
} = useAppSelector(mainSelector);
|
||||
|
||||
const { images, isLoading, allImagesTotal, categories, selectedBoardId } =
|
||||
useAppSelector(itemSelector);
|
||||
} = useAppSelector(selector);
|
||||
|
||||
const { selectedBoard } = useListAllBoardsQuery(undefined, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
@ -208,11 +182,14 @@ const ImageGalleryContent = () => {
|
||||
return () => osInstance()?.destroy();
|
||||
}, [scroller, initialize, osInstance]);
|
||||
|
||||
const setScrollerRef = useCallback((ref: HTMLElement | Window | null) => {
|
||||
if (ref instanceof HTMLElement) {
|
||||
setScroller(ref);
|
||||
}
|
||||
}, []);
|
||||
useEffect(() => {
|
||||
dispatch(
|
||||
receivedPageOfImages({
|
||||
categories: ['general'],
|
||||
is_intermediate: false,
|
||||
})
|
||||
);
|
||||
}, [dispatch]);
|
||||
|
||||
const handleClickImagesCategory = useCallback(() => {
|
||||
dispatch(imageCategoriesChanged(IMAGE_CATEGORIES));
|
||||
@ -314,29 +291,11 @@ const ImageGalleryContent = () => {
|
||||
withReset
|
||||
handleReset={() => dispatch(setGalleryImageMinimumWidth(64))}
|
||||
/>
|
||||
<IAISimpleCheckbox
|
||||
label={t('gallery.maintainAspectRatio')}
|
||||
isChecked={galleryImageObjectFit === 'contain'}
|
||||
onChange={() =>
|
||||
dispatch(
|
||||
setGalleryImageObjectFit(
|
||||
galleryImageObjectFit === 'contain' ? 'cover' : 'contain'
|
||||
)
|
||||
)
|
||||
}
|
||||
/>
|
||||
<IAISimpleCheckbox
|
||||
label={t('gallery.autoSwitchNewImages')}
|
||||
isChecked={shouldAutoSwitchToNewImages}
|
||||
isChecked={shouldAutoSwitch}
|
||||
onChange={(e: ChangeEvent<HTMLInputElement>) =>
|
||||
dispatch(setShouldAutoSwitchToNewImages(e.target.checked))
|
||||
}
|
||||
/>
|
||||
<IAISimpleCheckbox
|
||||
label={t('gallery.singleColumnLayout')}
|
||||
isChecked={shouldUseSingleGalleryColumn}
|
||||
onChange={(e: ChangeEvent<HTMLInputElement>) =>
|
||||
dispatch(setShouldUseSingleGalleryColumn(e.target.checked))
|
||||
dispatch(shouldAutoSwitchChanged(e.target.checked))
|
||||
}
|
||||
/>
|
||||
</Flex>
|
||||
@ -358,41 +317,28 @@ const ImageGalleryContent = () => {
|
||||
{images.length || areMoreAvailable ? (
|
||||
<>
|
||||
<Box ref={rootRef} data-overlayscrollbars="" h="100%">
|
||||
{shouldUseSingleGalleryColumn ? (
|
||||
<Virtuoso
|
||||
style={{ height: '100%' }}
|
||||
data={images}
|
||||
endReached={handleEndReached}
|
||||
scrollerRef={(ref) => setScrollerRef(ref)}
|
||||
itemContent={(index, item) => (
|
||||
<Flex sx={{ pb: 2 }}>
|
||||
<HoverableImage
|
||||
key={`${item.image_name}-${item.thumbnail_url}`}
|
||||
image={item}
|
||||
isSelected={selectedImage === item?.image_name}
|
||||
/>
|
||||
</Flex>
|
||||
)}
|
||||
/>
|
||||
) : (
|
||||
<VirtuosoGrid
|
||||
style={{ height: '100%' }}
|
||||
data={images}
|
||||
endReached={handleEndReached}
|
||||
components={{
|
||||
Item: ItemContainer,
|
||||
List: ListContainer,
|
||||
}}
|
||||
scrollerRef={setScroller}
|
||||
itemContent={(index, item) => (
|
||||
<HoverableImage
|
||||
key={`${item.image_name}-${item.thumbnail_url}`}
|
||||
image={item}
|
||||
isSelected={selectedImage === item?.image_name}
|
||||
<VirtuosoGrid
|
||||
style={{ height: '100%' }}
|
||||
data={images}
|
||||
endReached={handleEndReached}
|
||||
components={{
|
||||
Item: ItemContainer,
|
||||
List: ListContainer,
|
||||
}}
|
||||
scrollerRef={setScroller}
|
||||
itemContent={(index, item) =>
|
||||
typeof item === 'string' ? (
|
||||
<Skeleton
|
||||
sx={{ w: 'full', h: 'full', aspectRatio: '1/1' }}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
)}
|
||||
) : (
|
||||
<GalleryImage
|
||||
key={`${item.image_name}-${item.thumbnail_url}`}
|
||||
imageDTO={item}
|
||||
/>
|
||||
)
|
||||
}
|
||||
/>
|
||||
</Box>
|
||||
<IAIButton
|
||||
onClick={handleLoadMoreImages}
|
||||
@ -407,27 +353,10 @@ const ImageGalleryContent = () => {
|
||||
</IAIButton>
|
||||
</>
|
||||
) : (
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
gap: 2,
|
||||
padding: 8,
|
||||
h: '100%',
|
||||
w: '100%',
|
||||
color: 'base.500',
|
||||
}}
|
||||
>
|
||||
<Icon
|
||||
as={MdPhotoLibrary}
|
||||
sx={{
|
||||
w: 16,
|
||||
h: 16,
|
||||
}}
|
||||
/>
|
||||
<Text textAlign="center">{t('gallery.noImagesInGallery')}</Text>
|
||||
</Flex>
|
||||
<IAINoContentFallback
|
||||
label={t('gallery.noImagesInGallery')}
|
||||
icon={FaImage}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
</VStack>
|
||||
@ -436,7 +365,7 @@ const ImageGalleryContent = () => {
|
||||
|
||||
type ItemContainerProps = PropsWithChildren & FlexProps;
|
||||
const ItemContainer = forwardRef((props: ItemContainerProps, ref) => (
|
||||
<Box className="item-container" ref={ref}>
|
||||
<Box className="item-container" ref={ref} p={1.5}>
|
||||
{props.children}
|
||||
</Box>
|
||||
));
|
||||
@ -453,8 +382,7 @@ const ListContainer = forwardRef((props: ListContainerProps, ref) => {
|
||||
className="list-container"
|
||||
ref={ref}
|
||||
sx={{
|
||||
gap: 2,
|
||||
gridTemplateColumns: `repeat(auto-fit, minmax(${galleryImageMinimumWidth}px, 1fr));`,
|
||||
gridTemplateColumns: `repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, 1fr));`,
|
||||
}}
|
||||
>
|
||||
{props.children}
|
||||
|
@ -5,14 +5,13 @@ import { clamp, isEqual } from 'lodash-es';
|
||||
import { useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaAngleLeft, FaAngleRight } from 'react-icons/fa';
|
||||
import { gallerySelector } from '../store/gallerySelectors';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { imageSelected } from '../store/gallerySlice';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import {
|
||||
selectFilteredImagesAsObject,
|
||||
selectFilteredImagesIds,
|
||||
} from '../store/imagesSlice';
|
||||
imageSelected,
|
||||
selectImagesById,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { selectFilteredImages } from 'features/gallery/store/gallerySlice';
|
||||
|
||||
const nextPrevButtonTriggerAreaStyles: ChakraProps['sx'] = {
|
||||
height: '100%',
|
||||
@ -25,45 +24,40 @@ const nextPrevButtonStyles: ChakraProps['sx'] = {
|
||||
};
|
||||
|
||||
export const nextPrevImageButtonsSelector = createSelector(
|
||||
[
|
||||
(state: RootState) => state,
|
||||
gallerySelector,
|
||||
selectFilteredImagesAsObject,
|
||||
selectFilteredImagesIds,
|
||||
],
|
||||
(state, gallery, filteredImagesAsObject, filteredImageIds) => {
|
||||
const { selectedImage } = gallery;
|
||||
[stateSelector, selectFilteredImages],
|
||||
(state, filteredImages) => {
|
||||
const lastSelectedImage =
|
||||
state.gallery.selection[state.gallery.selection.length - 1];
|
||||
|
||||
if (!selectedImage) {
|
||||
if (!lastSelectedImage || filteredImages.length === 0) {
|
||||
return {
|
||||
isOnFirstImage: true,
|
||||
isOnLastImage: true,
|
||||
};
|
||||
}
|
||||
|
||||
const currentImageIndex = filteredImageIds.findIndex(
|
||||
(i) => i === selectedImage
|
||||
const currentImageIndex = filteredImages.findIndex(
|
||||
(i) => i.image_name === lastSelectedImage
|
||||
);
|
||||
|
||||
const nextImageIndex = clamp(
|
||||
currentImageIndex + 1,
|
||||
0,
|
||||
filteredImageIds.length - 1
|
||||
filteredImages.length - 1
|
||||
);
|
||||
|
||||
const prevImageIndex = clamp(
|
||||
currentImageIndex - 1,
|
||||
0,
|
||||
filteredImageIds.length - 1
|
||||
filteredImages.length - 1
|
||||
);
|
||||
|
||||
const nextImageId = filteredImageIds[nextImageIndex];
|
||||
const prevImageId = filteredImageIds[prevImageIndex];
|
||||
const nextImageId = filteredImages[nextImageIndex].image_name;
|
||||
const prevImageId = filteredImages[prevImageIndex].image_name;
|
||||
|
||||
const nextImage = filteredImagesAsObject[nextImageId];
|
||||
const prevImage = filteredImagesAsObject[prevImageId];
|
||||
const nextImage = selectImagesById(state, nextImageId);
|
||||
const prevImage = selectImagesById(state, prevImageId);
|
||||
|
||||
const imagesLength = filteredImageIds.length;
|
||||
const imagesLength = filteredImages.length;
|
||||
|
||||
return {
|
||||
isOnFirstImage: currentImageIndex === 0,
|
||||
@ -101,11 +95,11 @@ const NextPrevImageButtons = () => {
|
||||
}, []);
|
||||
|
||||
const handlePrevImage = useCallback(() => {
|
||||
dispatch(imageSelected(prevImageId));
|
||||
prevImageId && dispatch(imageSelected(prevImageId));
|
||||
}, [dispatch, prevImageId]);
|
||||
|
||||
const handleNextImage = useCallback(() => {
|
||||
dispatch(imageSelected(nextImageId));
|
||||
nextImageId && dispatch(imageSelected(nextImageId));
|
||||
}, [dispatch, nextImageId]);
|
||||
|
||||
useHotkeys(
|
||||
|
@ -1,40 +0,0 @@
|
||||
import { useColorMode, useToken } from '@chakra-ui/react';
|
||||
import { motion } from 'framer-motion';
|
||||
import { mode } from 'theme/util/mode';
|
||||
|
||||
export const SelectedItemOverlay = () => {
|
||||
const [accent400, accent500] = useToken('colors', [
|
||||
'accent.400',
|
||||
'accent.500',
|
||||
]);
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
return (
|
||||
<motion.div
|
||||
initial={{
|
||||
opacity: 0,
|
||||
}}
|
||||
animate={{
|
||||
opacity: 1,
|
||||
transition: { duration: 0.1 },
|
||||
}}
|
||||
exit={{
|
||||
opacity: 0,
|
||||
transition: { duration: 0.1 },
|
||||
}}
|
||||
style={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
insetInlineStart: 0,
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
boxShadow: `inset 0px 0px 0px 2px ${mode(
|
||||
accent400,
|
||||
accent500
|
||||
)(colorMode)}`,
|
||||
borderRadius: 'var(--invokeai-radii-base)',
|
||||
}}
|
||||
/>
|
||||
);
|
||||
};
|
@ -1,18 +0,0 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectImagesEntities } from '../store/imagesSlice';
|
||||
import { useCallback } from 'react';
|
||||
|
||||
const useGetImageByName = () => {
|
||||
const images = useAppSelector(selectImagesEntities);
|
||||
return useCallback(
|
||||
(name: string | undefined) => {
|
||||
if (!name) {
|
||||
return;
|
||||
}
|
||||
return images[name];
|
||||
},
|
||||
[images]
|
||||
);
|
||||
};
|
||||
|
||||
export default useGetImageByName;
|
@ -1,15 +1,6 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { ImageUsage } from 'app/contexts/DeleteImageContext';
|
||||
import { ImageDTO, BoardDTO } from 'services/api/types';
|
||||
|
||||
export type RequestedImageDeletionArg = {
|
||||
image: ImageDTO;
|
||||
imageUsage: ImageUsage;
|
||||
};
|
||||
|
||||
export const requestedImageDeletion = createAction<RequestedImageDeletionArg>(
|
||||
'gallery/requestedImageDeletion'
|
||||
);
|
||||
import { ImageUsage } from 'app/contexts/AddImageToBoardContext';
|
||||
import { BoardDTO } from 'services/api/types';
|
||||
|
||||
export type RequestedBoardImagesDeletionArg = {
|
||||
board: BoardDTO;
|
||||
|
@ -1,10 +1,8 @@
|
||||
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
|
||||
type BoardsState = {
|
||||
searchText: string;
|
||||
selectedBoardId?: string;
|
||||
updateBoardModalOpen: boolean;
|
||||
};
|
||||
|
||||
@ -17,9 +15,6 @@ const boardsSlice = createSlice({
|
||||
name: 'boards',
|
||||
initialState: initialBoardsState,
|
||||
reducers: {
|
||||
boardIdSelected: (state, action: PayloadAction<string | undefined>) => {
|
||||
state.selectedBoardId = action.payload;
|
||||
},
|
||||
setBoardSearchText: (state, action: PayloadAction<string>) => {
|
||||
state.searchText = action.payload;
|
||||
},
|
||||
@ -27,19 +22,9 @@ const boardsSlice = createSlice({
|
||||
state.updateBoardModalOpen = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder.addMatcher(
|
||||
boardsApi.endpoints.deleteBoard.matchFulfilled,
|
||||
(state, action) => {
|
||||
if (action.meta.arg.originalArgs === state.selectedBoardId) {
|
||||
state.selectedBoardId = undefined;
|
||||
}
|
||||
}
|
||||
);
|
||||
},
|
||||
});
|
||||
|
||||
export const { boardIdSelected, setBoardSearchText, setUpdateBoardModalOpen } =
|
||||
export const { setBoardSearchText, setUpdateBoardModalOpen } =
|
||||
boardsSlice.actions;
|
||||
|
||||
export const boardsSelector = (state: RootState) => state.boards;
|
||||
|
@ -1,8 +1,15 @@
|
||||
import { GalleryState } from './gallerySlice';
|
||||
import { initialGalleryState } from './gallerySlice';
|
||||
|
||||
/**
|
||||
* Gallery slice persist denylist
|
||||
*/
|
||||
export const galleryPersistDenylist: (keyof GalleryState)[] = [
|
||||
'shouldAutoSwitchToNewImages',
|
||||
export const galleryPersistDenylist: (keyof typeof initialGalleryState)[] = [
|
||||
'selection',
|
||||
'entities',
|
||||
'ids',
|
||||
'isLoading',
|
||||
'limit',
|
||||
'offset',
|
||||
'selectedBoardId',
|
||||
'total',
|
||||
];
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user