diff --git a/.gitignore b/.gitignore index 7f3b1278df..e9918d4fb5 100644 --- a/.gitignore +++ b/.gitignore @@ -201,8 +201,6 @@ checkpoints # If it's a Mac .DS_Store -invokeai/frontend/web/dist/* - # Let the frontend manage its own gitignore !invokeai/frontend/web/* diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 48eed0e4b9..4e23a69d90 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -27,17 +27,13 @@ class ModelsList(BaseModel): models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]] @models_router.get( - "/{base_model}/{model_type}", + "/", operation_id="list_models", responses={200: {"model": ModelsList }}, ) async def list_models( - base_model: Optional[BaseModelType] = Path( - default=None, description="Base model" - ), - model_type: Optional[ModelType] = Path( - default=None, description="The type of model to get" - ), + base_model: Optional[BaseModelType] = Query(default=None, description="Base model"), + model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), ) -> ModelsList: """Gets a list of models""" models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type) @@ -55,10 +51,10 @@ async def list_models( response_model = UpdateModelResponse, ) async def update_model( - base_model: BaseModelType = Path(default='sd-1', description="Base model"), - model_type: ModelType = Path(default='main', description="The type of model"), - model_name: str = Path(default=None, description="model name"), - info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), + base_model: BaseModelType = Path(description="Base model"), + model_type: ModelType = Path(description="The type of model"), + model_name: str = Path(description="model name"), + info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), ) -> UpdateModelResponse: """ Add Model """ try: diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 1bf9353368..4c7314bd2b 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -4,9 +4,10 @@ from __future__ import annotations from abc import ABC, abstractmethod from inspect import signature -from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict, TYPE_CHECKING +from typing import (TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args, + get_type_hints) -from pydantic import BaseModel, Field +from pydantic import BaseConfig, BaseModel, Field if TYPE_CHECKING: from ..services.invocation_services import InvocationServices @@ -65,8 +66,13 @@ class BaseInvocation(ABC, BaseModel): @classmethod def get_invocations_map(cls): # Get the type strings out of the literals and into a dictionary - return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseInvocation.get_all_subclasses())) - + return dict( + map( + lambda t: (get_args(get_type_hints(t)["type"])[0], t), + BaseInvocation.get_all_subclasses(), + ) + ) + @classmethod def get_output_type(cls): return signature(cls.invoke).return_annotation @@ -75,11 +81,11 @@ class BaseInvocation(ABC, BaseModel): def invoke(self, context: InvocationContext) -> BaseInvocationOutput: """Invoke with provided context and return outputs.""" pass - - #fmt: off + + # fmt: off id: str = Field(description="The id of this node. Must be unique among all nodes.") is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.") - #fmt: on + # fmt: on # TODO: figure out a better way to provide these hints @@ -98,16 +104,19 @@ class UIConfig(TypedDict, total=False): "model", "control", "image_collection", + "vae_model", + "lora_model", ], ] tags: List[str] title: str + class CustomisedSchemaExtra(TypedDict): ui: UIConfig -class InvocationConfig(BaseModel.Config): +class InvocationConfig(BaseConfig): """Customizes pydantic's BaseModel.Config class for use by Invocations. Provide `schema_extra` a `ui` dict to add hints for generated UIs. diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 0421841e8a..4850b9670d 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -1,28 +1,28 @@ -from typing import Literal, Optional, Union -from pydantic import BaseModel, Field -from contextlib import ExitStack import re +from contextlib import ExitStack +from typing import List, Literal, Optional, Union + import torch +from compel import Compel +from compel.prompt_parser import (Blend, Conjunction, + CrossAttentionControlSubstitute, + FlattenedPrompt, Fragment) +from pydantic import BaseModel, Field -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig -from .model import ClipField - -from ...backend.util.devices import torch_dtype -from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent +from ...backend.model_management.models import ModelNotFoundException from ...backend.model_management import BaseModelType, ModelType, SubModelType from ...backend.model_management.lora import ModelPatcher - -from compel import Compel -from compel.prompt_parser import ( - Blend, - CrossAttentionControlSubstitute, - FlattenedPrompt, - Fragment, Conjunction, -) +from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent +from ...backend.util.devices import torch_dtype +from .baseinvocation import (BaseInvocation, BaseInvocationOutput, + InvocationConfig, InvocationContext) +from .model import ClipField class ConditioningField(BaseModel): - conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data") + conditioning_name: Optional[str] = Field( + default=None, description="The name of conditioning data") + class Config: schema_extra = {"required": ["conditioning_name"]} @@ -52,84 +52,92 @@ class CompelInvocation(BaseInvocation): "title": "Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": { - "model": "model" + "model": "model" } }, } @torch.no_grad() def invoke(self, context: InvocationContext) -> CompelOutput: - tokenizer_info = context.services.model_manager.get_model( **self.clip.tokenizer.dict(), ) text_encoder_info = context.services.model_manager.get_model( **self.clip.text_encoder.dict(), ) - with tokenizer_info as orig_tokenizer,\ - text_encoder_info as text_encoder: - loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] + def _lora_loader(): + for lora in self.clip.loras: + lora_info = context.services.model_manager.get_model( + **lora.dict(exclude={"weight"})) + yield (lora_info.context.model, lora.weight) + del lora_info + return - ti_list = [] - for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): - name = trigger[1:-1] - try: - ti_list.append( - context.services.model_manager.get_model( - model_name=name, - base_model=self.clip.text_encoder.base_model, - model_type=ModelType.TextualInversion, - ).context.model - ) - except Exception: - #print(e) - #import traceback - #print(traceback.format_exc()) - print(f"Warn: trigger: \"{trigger}\" not found") + #loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] - with ModelPatcher.apply_lora_text_encoder(text_encoder, loras),\ - ModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager): - - compel = Compel( - tokenizer=tokenizer, - text_encoder=text_encoder, - textual_inversion_manager=ti_manager, - dtype_for_device_getter=torch_dtype, - truncate_long_prompts=True, # TODO: + ti_list = [] + for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): + name = trigger[1:-1] + try: + ti_list.append( + context.services.model_manager.get_model( + model_name=name, + base_model=self.clip.text_encoder.base_model, + model_type=ModelType.TextualInversion, + ).context.model ) - - conjunction = Compel.parse_prompt_string(self.prompt) - prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0] + except ModelNotFoundException: + # print(e) + #import traceback + #print(traceback.format_exc()) + print(f"Warn: trigger: \"{trigger}\" not found") - if context.services.configuration.log_tokenization: - log_tokenization_for_prompt_object(prompt, tokenizer) + with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\ + ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\ + text_encoder_info as text_encoder: - c, options = compel.build_conditioning_tensor_for_prompt_object(prompt) - - # TODO: long prompt support - #if not self.truncate_long_prompts: - # [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc]) - ec = InvokeAIDiffuserComponent.ExtraConditioningInfo( - tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction), - cross_attention_control_args=options.get("cross_attention_control", None), - ) - - conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - - # TODO: hacky but works ;D maybe rename latents somehow? - context.services.latents.save(conditioning_name, (c, ec)) - - return CompelOutput( - conditioning=ConditioningField( - conditioning_name=conditioning_name, - ), + compel = Compel( + tokenizer=tokenizer, + text_encoder=text_encoder, + textual_inversion_manager=ti_manager, + dtype_for_device_getter=torch_dtype, + truncate_long_prompts=True, # TODO: ) + conjunction = Compel.parse_prompt_string(self.prompt) + prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0] + + if context.services.configuration.log_tokenization: + log_tokenization_for_prompt_object(prompt, tokenizer) + + c, options = compel.build_conditioning_tensor_for_prompt_object( + prompt) + + # TODO: long prompt support + # if not self.truncate_long_prompts: + # [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc]) + ec = InvokeAIDiffuserComponent.ExtraConditioningInfo( + tokens_count_including_eos_bos=get_max_token_count( + tokenizer, conjunction), + cross_attention_control_args=options.get( + "cross_attention_control", None),) + + conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" + + # TODO: hacky but works ;D maybe rename latents somehow? + context.services.latents.save(conditioning_name, (c, ec)) + + return CompelOutput( + conditioning=ConditioningField( + conditioning_name=conditioning_name, + ), + ) + def get_max_token_count( - tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False -) -> int: + tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], + truncate_if_too_long=False) -> int: if type(prompt) is Blend: blend: Blend = prompt return max( @@ -148,13 +156,13 @@ def get_max_token_count( ) else: return len( - get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long) - ) + get_tokens_for_prompt_object( + tokenizer, prompt, truncate_if_too_long)) def get_tokens_for_prompt_object( tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True -) -> [str]: +) -> List[str]: if type(parsed_prompt) is Blend: raise ValueError( "Blend is not supported here - you need to get tokens for each of its .children" @@ -183,7 +191,7 @@ def log_tokenization_for_conjunction( ): display_label_prefix = display_label_prefix or "" for i, p in enumerate(c.prompts): - if len(c.prompts)>1: + if len(c.prompts) > 1: this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})" else: this_display_label_prefix = display_label_prefix @@ -238,7 +246,8 @@ def log_tokenization_for_prompt_object( ) -def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False): +def log_tokenization_for_text( + text, tokenizer, display_label=None, truncate_if_too_long=False): """shows how the prompt is tokenized # usually tokens have '' to indicate end-of-word, # but for readability it has been replaced with ' ' diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index a9576a2fe1..3e691c934e 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -4,18 +4,17 @@ from contextlib import ExitStack from typing import List, Literal, Optional, Union import einops - -from pydantic import BaseModel, Field, validator import torch from diffusers import ControlNetModel, DPMSolverMultistepScheduler from diffusers.image_processor import VaeImageProcessor from diffusers.schedulers import SchedulerMixin as Scheduler +from pydantic import BaseModel, Field, validator from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.step_callback import stable_diffusion_step_callback -from ..models.image import ImageCategory, ImageField, ResourceOrigin from ...backend.image_util.seamless import configure_model_padding +from ...backend.model_management.lora import ModelPatcher from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion.diffusers_pipeline import ( ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline, @@ -24,7 +23,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \ PostprocessingSettings from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.util.devices import torch_dtype -from ...backend.model_management.lora import ModelPatcher +from ..models.image import ImageCategory, ImageField, ResourceOrigin from .baseinvocation import (BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext) from .compel import ConditioningField @@ -32,14 +31,17 @@ from .controlnet_image_processors import ControlField from .image import ImageOutput from .model import ModelInfo, UNetField, VaeField + class LatentsField(BaseModel): """A latents field used for passing latents between invocations""" - latents_name: Optional[str] = Field(default=None, description="The name of the latents") + latents_name: Optional[str] = Field( + default=None, description="The name of the latents") class Config: schema_extra = {"required": ["latents_name"]} + class LatentsOutput(BaseInvocationOutput): """Base class for invocations that output latents""" #fmt: off @@ -53,11 +55,11 @@ class LatentsOutput(BaseInvocationOutput): def build_latents_output(latents_name: str, latents: torch.Tensor): - return LatentsOutput( - latents=LatentsField(latents_name=latents_name), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, - ) + return LatentsOutput( + latents=LatentsField(latents_name=latents_name), + width=latents.size()[3] * 8, + height=latents.size()[2] * 8, + ) SAMPLER_NAME_VALUES = Literal[ @@ -70,16 +72,19 @@ def get_scheduler( scheduler_info: ModelInfo, scheduler_name: str, ) -> Scheduler: - scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim']) - orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict()) + scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get( + scheduler_name, SCHEDULER_MAP['ddim']) + orig_scheduler_info = context.services.model_manager.get_model( + **scheduler_info.dict()) with orig_scheduler_info as orig_scheduler: scheduler_config = orig_scheduler.config - + if "_backup" in scheduler_config: scheduler_config = scheduler_config["_backup"] - scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config} + scheduler_config = {**scheduler_config, ** + scheduler_extra_config, "_backup": scheduler_config} scheduler = scheduler_class.from_config(scheduler_config) - + # hack copied over from generate.py if not hasattr(scheduler, 'uses_inpainting_model'): scheduler.uses_inpainting_model = lambda: False @@ -124,18 +129,18 @@ class TextToLatentsInvocation(BaseInvocation): "ui": { "tags": ["latents"], "type_hints": { - "model": "model", - "control": "control", - # "cfg_scale": "float", - "cfg_scale": "number" + "model": "model", + "control": "control", + # "cfg_scale": "float", + "cfg_scale": "number" } }, } # TODO: pass this an emitter method or something? or a session for dispatching? def dispatch_progress( - self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState - ) -> None: + self, context: InvocationContext, source_node_id: str, + intermediate_state: PipelineIntermediateState) -> None: stable_diffusion_step_callback( context=context, intermediate_state=intermediate_state, @@ -143,9 +148,12 @@ class TextToLatentsInvocation(BaseInvocation): source_node_id=source_node_id, ) - def get_conditioning_data(self, context: InvocationContext, scheduler) -> ConditioningData: - c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name) - uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name) + def get_conditioning_data( + self, context: InvocationContext, scheduler) -> ConditioningData: + c, extra_conditioning_info = context.services.latents.get( + self.positive_conditioning.conditioning_name) + uc, _ = context.services.latents.get( + self.negative_conditioning.conditioning_name) conditioning_data = ConditioningData( unconditioned_embeddings=uc, @@ -153,10 +161,10 @@ class TextToLatentsInvocation(BaseInvocation): guidance_scale=self.cfg_scale, extra=extra_conditioning_info, postprocessing_settings=PostprocessingSettings( - threshold=0.0,#threshold, - warmup=0.2,#warmup, - h_symmetry_time_pct=None,#h_symmetry_time_pct, - v_symmetry_time_pct=None#v_symmetry_time_pct, + threshold=0.0, # threshold, + warmup=0.2, # warmup, + h_symmetry_time_pct=None, # h_symmetry_time_pct, + v_symmetry_time_pct=None # v_symmetry_time_pct, ), ) @@ -164,31 +172,32 @@ class TextToLatentsInvocation(BaseInvocation): scheduler, # for ddim scheduler - eta=0.0, #ddim_eta + eta=0.0, # ddim_eta # for ancestral and sde schedulers generator=torch.Generator(device=uc.device).manual_seed(0), ) return conditioning_data - def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline: + def create_pipeline( + self, unet, scheduler) -> StableDiffusionGeneratorPipeline: # TODO: - #configure_model_padding( + # configure_model_padding( # unet, # self.seamless, # self.seamless_axes, - #) + # ) class FakeVae: class FakeVaeConfig: def __init__(self): self.block_out_channels = [0] - + def __init__(self): self.config = FakeVae.FakeVaeConfig() return StableDiffusionGeneratorPipeline( - vae=FakeVae(), # TODO: oh... + vae=FakeVae(), # TODO: oh... text_encoder=None, tokenizer=None, unet=unet, @@ -198,11 +207,12 @@ class TextToLatentsInvocation(BaseInvocation): requires_safety_checker=False, precision="float16" if unet.dtype == torch.float16 else "float32", ) - + def prep_control_data( self, context: InvocationContext, - model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device + # really only need model for dtype and device + model: StableDiffusionGeneratorPipeline, control_input: List[ControlField], latents_shape: List[int], do_classifier_free_guidance: bool = True, @@ -238,15 +248,17 @@ class TextToLatentsInvocation(BaseInvocation): print("Using HF model subfolders") print(" control_name: ", control_name) print(" control_subfolder: ", control_subfolder) - control_model = ControlNetModel.from_pretrained(control_name, - subfolder=control_subfolder, - torch_dtype=model.unet.dtype).to(model.device) + control_model = ControlNetModel.from_pretrained( + control_name, subfolder=control_subfolder, + torch_dtype=model.unet.dtype).to( + model.device) else: - control_model = ControlNetModel.from_pretrained(control_info.control_model, - torch_dtype=model.unet.dtype).to(model.device) + control_model = ControlNetModel.from_pretrained( + control_info.control_model, torch_dtype=model.unet.dtype).to(model.device) control_models.append(control_model) control_image_field = control_info.image - input_image = context.services.images.get_pil_image(control_image_field.image_name) + input_image = context.services.images.get_pil_image( + control_image_field.image_name) # self.image.image_type, self.image.image_name # FIXME: still need to test with different widths, heights, devices, dtypes # and add in batch_size, num_images_per_prompt? @@ -263,41 +275,50 @@ class TextToLatentsInvocation(BaseInvocation): dtype=control_model.dtype, control_mode=control_info.control_mode, ) - control_item = ControlNetData(model=control_model, - image_tensor=control_image, - weight=control_info.control_weight, - begin_step_percent=control_info.begin_step_percent, - end_step_percent=control_info.end_step_percent, - control_mode=control_info.control_mode, - ) + control_item = ControlNetData( + model=control_model, image_tensor=control_image, + weight=control_info.control_weight, + begin_step_percent=control_info.begin_step_percent, + end_step_percent=control_info.end_step_percent, + control_mode=control_info.control_mode,) control_data.append(control_item) # MultiControlNetModel has been refactored out, just need list[ControlNetData] return control_data + @torch.no_grad() def invoke(self, context: InvocationContext) -> LatentsOutput: noise = context.services.latents.get(self.noise.latents_name) # Get the source node id (we are invoking the prepared node) - graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) + graph_execution_state = context.services.graph_execution_manager.get( + context.graph_execution_state_id) source_node_id = graph_execution_state.prepared_source_mapping[self.id] def step_callback(state: PipelineIntermediateState): self.dispatch_progress(context, source_node_id, state) - unet_info = context.services.model_manager.get_model(**self.unet.unet.dict()) - with unet_info as unet: + def _lora_loader(): + for lora in self.unet.loras: + lora_info = context.services.model_manager.get_model( + **lora.dict(exclude={"weight"})) + yield (lora_info.context.model, lora.weight) + del lora_info + return + + unet_info = context.services.model_manager.get_model( + **self.unet.unet.dict()) + with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ + unet_info as unet: scheduler = get_scheduler( context=context, scheduler_info=self.unet.scheduler, scheduler_name=self.scheduler, ) - + pipeline = self.create_pipeline(unet, scheduler) conditioning_data = self.get_conditioning_data(context, scheduler) - loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras] - control_data = self.prep_control_data( model=pipeline, context=context, control_input=self.control, latents_shape=noise.shape, @@ -305,16 +326,15 @@ class TextToLatentsInvocation(BaseInvocation): do_classifier_free_guidance=True, ) - with ModelPatcher.apply_lora_unet(pipeline.unet, loras): - # TODO: Verify the noise is the right size - result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( - latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)), - noise=noise, - num_inference_steps=self.steps, - conditioning_data=conditioning_data, - control_data=control_data, # list[ControlNetData] - callback=step_callback, - ) + # TODO: Verify the noise is the right size + result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( + latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)), + noise=noise, + num_inference_steps=self.steps, + conditioning_data=conditioning_data, + control_data=control_data, # list[ControlNetData] + callback=step_callback, + ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() @@ -323,14 +343,18 @@ class TextToLatentsInvocation(BaseInvocation): context.services.latents.save(name, result_latents) return build_latents_output(latents_name=name, latents=result_latents) + class LatentsToLatentsInvocation(TextToLatentsInvocation): """Generates latents using latents as base image.""" type: Literal["l2l"] = "l2l" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to use as a base image") - strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use") + latents: Optional[LatentsField] = Field( + description="The latents to use as a base image") + strength: float = Field( + default=0.7, ge=0, le=1, + description="The strength of the latents to use") # Schema customisation class Config(InvocationConfig): @@ -345,22 +369,31 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): }, } + @torch.no_grad() def invoke(self, context: InvocationContext) -> LatentsOutput: noise = context.services.latents.get(self.noise.latents_name) latent = context.services.latents.get(self.latents.latents_name) # Get the source node id (we are invoking the prepared node) - graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) + graph_execution_state = context.services.graph_execution_manager.get( + context.graph_execution_state_id) source_node_id = graph_execution_state.prepared_source_mapping[self.id] def step_callback(state: PipelineIntermediateState): self.dispatch_progress(context, source_node_id, state) - unet_info = context.services.model_manager.get_model( - **self.unet.unet.dict(), - ) + def _lora_loader(): + for lora in self.unet.loras: + lora_info = context.services.model_manager.get_model( + **lora.dict(exclude={"weight"})) + yield (lora_info.context.model, lora.weight) + del lora_info + return - with unet_info as unet: + unet_info = context.services.model_manager.get_model( + **self.unet.unet.dict()) + with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ + unet_info as unet: scheduler = get_scheduler( context=context, @@ -370,7 +403,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): pipeline = self.create_pipeline(unet, scheduler) conditioning_data = self.get_conditioning_data(context, scheduler) - + control_data = self.prep_control_data( model=pipeline, context=context, control_input=self.control, latents_shape=noise.shape, @@ -380,8 +413,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): # TODO: Verify the noise is the right size initial_latents = latent if self.strength < 1.0 else torch.zeros_like( - latent, device=unet.device, dtype=latent.dtype - ) + latent, device=unet.device, dtype=latent.dtype) timesteps, _ = pipeline.get_img2img_timesteps( self.steps, @@ -389,18 +421,15 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): device=unet.device, ) - loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras] - - with ModelPatcher.apply_lora_unet(pipeline.unet, loras): - result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( - latents=initial_latents, - timesteps=timesteps, - noise=noise, - num_inference_steps=self.steps, - conditioning_data=conditioning_data, - control_data=control_data, # list[ControlNetData] - callback=step_callback - ) + result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( + latents=initial_latents, + timesteps=timesteps, + noise=noise, + num_inference_steps=self.steps, + conditioning_data=conditioning_data, + control_data=control_data, # list[ControlNetData] + callback=step_callback + ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() @@ -417,9 +446,12 @@ class LatentsToImageInvocation(BaseInvocation): type: Literal["l2i"] = "l2i" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to generate an image from") + latents: Optional[LatentsField] = Field( + description="The latents to generate an image from") vae: VaeField = Field(default=None, description="Vae submodel") - tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)") + tiled: bool = Field( + default=False, + description="Decode latents by overlaping tiles(less memory consumption)") # Schema customisation class Config(InvocationConfig): @@ -450,7 +482,7 @@ class LatentsToImageInvocation(BaseInvocation): # copied from diffusers pipeline latents = latents / vae.config.scaling_factor image = vae.decode(latents, return_dict=False)[0] - image = (image / 2 + 0.5).clamp(0, 1) # denormalize + image = (image / 2 + 0.5).clamp(0, 1) # denormalize # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 np_image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -473,9 +505,9 @@ class LatentsToImageInvocation(BaseInvocation): height=image_dto.height, ) -LATENTS_INTERPOLATION_MODE = Literal[ - "nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact" -] + +LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", + "bilinear", "bicubic", "trilinear", "area", "nearest-exact"] class ResizeLatentsInvocation(BaseInvocation): @@ -484,21 +516,25 @@ class ResizeLatentsInvocation(BaseInvocation): type: Literal["lresize"] = "lresize" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to resize") - width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)") - height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)") - mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode") - antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") + latents: Optional[LatentsField] = Field( + description="The latents to resize") + width: int = Field( + ge=64, multiple_of=8, description="The width to resize to (px)") + height: int = Field( + ge=64, multiple_of=8, description="The height to resize to (px)") + mode: LATENTS_INTERPOLATION_MODE = Field( + default="bilinear", description="The interpolation mode") + antialias: bool = Field( + default=False, + description="Whether or not to antialias (applied in bilinear and bicubic modes only)") def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.services.latents.get(self.latents.latents_name) resized_latents = torch.nn.functional.interpolate( - latents, - size=(self.height // 8, self.width // 8), - mode=self.mode, - antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False, - ) + latents, size=(self.height // 8, self.width // 8), + mode=self.mode, antialias=self.antialias + if self.mode in ["bilinear", "bicubic"] else False,) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() @@ -515,21 +551,24 @@ class ScaleLatentsInvocation(BaseInvocation): type: Literal["lscale"] = "lscale" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to scale") - scale_factor: float = Field(gt=0, description="The factor by which to scale the latents") - mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode") - antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") + latents: Optional[LatentsField] = Field( + description="The latents to scale") + scale_factor: float = Field( + gt=0, description="The factor by which to scale the latents") + mode: LATENTS_INTERPOLATION_MODE = Field( + default="bilinear", description="The interpolation mode") + antialias: bool = Field( + default=False, + description="Whether or not to antialias (applied in bilinear and bicubic modes only)") def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.services.latents.get(self.latents.latents_name) # resizing resized_latents = torch.nn.functional.interpolate( - latents, - scale_factor=self.scale_factor, - mode=self.mode, - antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False, - ) + latents, scale_factor=self.scale_factor, mode=self.mode, + antialias=self.antialias + if self.mode in ["bilinear", "bicubic"] else False,) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() @@ -548,7 +587,9 @@ class ImageToLatentsInvocation(BaseInvocation): # Inputs image: Union[ImageField, None] = Field(description="The image to encode") vae: VaeField = Field(default=None, description="Vae submodel") - tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)") + tiled: bool = Field( + default=False, + description="Encode latents by overlaping tiles(less memory consumption)") # Schema customisation class Config(InvocationConfig): diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index e51873c59e..17297ba417 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -1,5 +1,5 @@ import copy -from typing import List, Literal, Optional +from typing import List, Literal, Optional, Union from pydantic import BaseModel, Field @@ -12,35 +12,42 @@ class ModelInfo(BaseModel): model_name: str = Field(description="Info to load submodel") base_model: BaseModelType = Field(description="Base model") model_type: ModelType = Field(description="Info to load submodel") - submodel: Optional[SubModelType] = Field(description="Info to load submodel") + submodel: Optional[SubModelType] = Field( + default=None, description="Info to load submodel" + ) + class LoraInfo(ModelInfo): weight: float = Field(description="Lora's weight which to use when apply to model") + class UNetField(BaseModel): unet: ModelInfo = Field(description="Info to load unet submodel") scheduler: ModelInfo = Field(description="Info to load scheduler submodel") loras: List[LoraInfo] = Field(description="Loras to apply on model loading") + class ClipField(BaseModel): tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel") text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel") loras: List[LoraInfo] = Field(description="Loras to apply on model loading") + class VaeField(BaseModel): # TODO: better naming? vae: ModelInfo = Field(description="Info to load vae submodel") + class ModelLoaderOutput(BaseInvocationOutput): """Model loader output""" - #fmt: off + # fmt: off type: Literal["model_loader_output"] = "model_loader_output" unet: UNetField = Field(default=None, description="UNet submodel") clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") vae: VaeField = Field(default=None, description="Vae submodel") - #fmt: on + # fmt: on class MainModelField(BaseModel): @@ -50,6 +57,13 @@ class MainModelField(BaseModel): base_model: BaseModelType = Field(description="Base model") +class LoRAModelField(BaseModel): + """LoRA model field""" + + model_name: str = Field(description="Name of the LoRA model") + base_model: BaseModelType = Field(description="Base model") + + class MainModelLoaderInvocation(BaseInvocation): """Loads a main model, outputting its submodels.""" @@ -64,14 +78,11 @@ class MainModelLoaderInvocation(BaseInvocation): "ui": { "title": "Model Loader", "tags": ["model", "loader"], - "type_hints": { - "model": "model" - } + "type_hints": {"model": "model"}, }, } def invoke(self, context: InvocationContext) -> ModelLoaderOutput: - base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.Main @@ -113,7 +124,6 @@ class MainModelLoaderInvocation(BaseInvocation): ) """ - return ModelLoaderOutput( unet=UNetField( unet=ModelInfo( @@ -152,25 +162,29 @@ class MainModelLoaderInvocation(BaseInvocation): model_type=model_type, submodel=SubModelType.Vae, ), - ) + ), ) + class LoraLoaderOutput(BaseInvocationOutput): """Model loader output""" - #fmt: off + # fmt: off type: Literal["lora_loader_output"] = "lora_loader_output" unet: Optional[UNetField] = Field(default=None, description="UNet submodel") clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels") - #fmt: on + # fmt: on + class LoraLoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" type: Literal["lora_loader"] = "lora_loader" - lora_name: str = Field(description="Lora model name") + lora: Union[LoRAModelField, None] = Field( + default=None, description="Lora model name" + ) weight: float = Field(default=0.75, description="With what weight to apply lora") unet: Optional[UNetField] = Field(description="UNet model for applying lora") @@ -181,26 +195,33 @@ class LoraLoaderInvocation(BaseInvocation): "ui": { "title": "Lora Loader", "tags": ["lora", "loader"], + "type_hints": {"lora": "lora_model"}, }, } def invoke(self, context: InvocationContext) -> LoraLoaderOutput: + if self.lora is None: + raise Exception("No LoRA provided") - # TODO: ui rewrite - base_model = BaseModelType.StableDiffusion1 + base_model = self.lora.base_model + lora_name = self.lora.model_name if not context.services.model_manager.model_exists( base_model=base_model, - model_name=self.lora_name, + model_name=lora_name, model_type=ModelType.Lora, ): - raise Exception(f"Unkown lora name: {self.lora_name}!") + raise Exception(f"Unkown lora name: {lora_name}!") - if self.unet is not None and any(lora.model_name == self.lora_name for lora in self.unet.loras): - raise Exception(f"Lora \"{self.lora_name}\" already applied to unet") + if self.unet is not None and any( + lora.model_name == lora_name for lora in self.unet.loras + ): + raise Exception(f'Lora "{lora_name}" already applied to unet') - if self.clip is not None and any(lora.model_name == self.lora_name for lora in self.clip.loras): - raise Exception(f"Lora \"{self.lora_name}\" already applied to clip") + if self.clip is not None and any( + lora.model_name == lora_name for lora in self.clip.loras + ): + raise Exception(f'Lora "{lora_name}" already applied to clip') output = LoraLoaderOutput() @@ -209,7 +230,7 @@ class LoraLoaderInvocation(BaseInvocation): output.unet.loras.append( LoraInfo( base_model=base_model, - model_name=self.lora_name, + model_name=lora_name, model_type=ModelType.Lora, submodel=None, weight=self.weight, @@ -221,7 +242,7 @@ class LoraLoaderInvocation(BaseInvocation): output.clip.loras.append( LoraInfo( base_model=base_model, - model_name=self.lora_name, + model_name=lora_name, model_type=ModelType.Lora, submodel=None, weight=self.weight, @@ -230,25 +251,29 @@ class LoraLoaderInvocation(BaseInvocation): return output + class VAEModelField(BaseModel): """Vae model field""" model_name: str = Field(description="Name of the model") base_model: BaseModelType = Field(description="Base model") + class VaeLoaderOutput(BaseInvocationOutput): """Model loader output""" - #fmt: off + # fmt: off type: Literal["vae_loader_output"] = "vae_loader_output" vae: VaeField = Field(default=None, description="Vae model") - #fmt: on + # fmt: on + class VaeLoaderInvocation(BaseInvocation): """Loads a VAE model, outputting a VaeLoaderOutput""" + type: Literal["vae_loader"] = "vae_loader" - + vae_model: VAEModelField = Field(description="The VAE to load") # Schema customisation @@ -257,29 +282,27 @@ class VaeLoaderInvocation(BaseInvocation): "ui": { "title": "VAE Loader", "tags": ["vae", "loader"], - "type_hints": { - "vae_model": "vae_model" - } + "type_hints": {"vae_model": "vae_model"}, }, } - + def invoke(self, context: InvocationContext) -> VaeLoaderOutput: base_model = self.vae_model.base_model model_name = self.vae_model.model_name model_type = ModelType.Vae if not context.services.model_manager.model_exists( - base_model=base_model, - model_name=model_name, - model_type=model_type, + base_model=base_model, + model_name=model_name, + model_type=model_type, ): raise Exception(f"Unkown vae name: {model_name}!") return VaeLoaderOutput( vae=VaeField( - vae = ModelInfo( - model_name = model_name, - base_model = base_model, - model_type = model_type, + vae=ModelInfo( + model_name=model_name, + base_model=base_model, + model_type=model_type, ) ) ) diff --git a/invokeai/app/services/config.py b/invokeai/app/services/config.py index e0f1ceeb25..e7f817fc0a 100644 --- a/invokeai/app/services/config.py +++ b/invokeai/app/services/config.py @@ -228,10 +228,10 @@ class InvokeAISettings(BaseSettings): upcase_environ = dict() for key,value in os.environ.items(): upcase_environ[key.upper()] = value - + fields = cls.__fields__ cls.argparse_groups = {} - + for name, field in fields.items(): if name not in cls._excluded(): current_default = field.default @@ -348,7 +348,7 @@ setting environment variables INVOKEAI_. ''' singleton_config: ClassVar[InvokeAIAppConfig] = None singleton_init: ClassVar[Dict] = None - + #fmt: off type: Literal["InvokeAI"] = "InvokeAI" host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server') @@ -367,7 +367,8 @@ setting environment variables INVOKEAI_. always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance') free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance') - max_loaded_models : int = Field(default=3, gt=0, description="Maximum number of models to keep in memory for rapid switching", category='Memory/Performance') + max_loaded_models : int = Field(default=3, gt=0, description="(DEPRECATED: use max_cache_size) Maximum number of models to keep in memory for rapid switching", category='Memory/Performance') + max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance') precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance') sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance') xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance') @@ -385,9 +386,9 @@ setting environment variables INVOKEAI_. outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths') from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths') use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths') - + model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models') - + log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=", "syslog=path|address:host:port", "http="', category="Logging") # note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues log_format : Literal[tuple(['plain','color','syslog','legacy'])] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging") @@ -396,7 +397,7 @@ setting environment variables INVOKEAI_. def parse_args(self, argv: List[str]=None, conf: DictConfig = None, clobber=False): ''' - Update settings with contents of init file, environment, and + Update settings with contents of init file, environment, and command-line settings. :param conf: alternate Omegaconf dictionary object :param argv: aternate sys.argv list @@ -411,7 +412,7 @@ setting environment variables INVOKEAI_. except: pass InvokeAISettings.initconf = conf - + # parse args again in order to pick up settings in configuration file super().parse_args(argv) @@ -431,7 +432,7 @@ setting environment variables INVOKEAI_. cls.singleton_config = cls(**kwargs) cls.singleton_init = kwargs return cls.singleton_config - + @property def root_path(self)->Path: ''' diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 300ae0fddd..eb2c014b1a 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -40,13 +40,13 @@ class ModelManagerServiceBase(ABC): logger: ModuleType, ): """ - Initialize with the path to the models.yaml config file. + Initialize with the path to the models.yaml config file. Optional parameters are the torch device type, precision, max_models, and sequential_offload boolean. Note that the default device type and precision are set up for a CUDA system running at half precision. """ pass - + @abstractmethod def get_model( self, @@ -57,8 +57,8 @@ class ModelManagerServiceBase(ABC): node: Optional[BaseInvocation] = None, context: Optional[InvocationContext] = None, ) -> ModelInfo: - """Retrieve the indicated model with name and type. - submodel can be used to get a part (such as the vae) + """Retrieve the indicated model with name and type. + submodel can be used to get a part (such as the vae) of a diffusers pipeline.""" pass @@ -129,8 +129,8 @@ class ModelManagerServiceBase(ABC): """ Update the named model with a dictionary of attributes. Will fail with an assertion error if the name already exists. Pass clobber=True to overwrite. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or + On a successful update, the config will be changed in memory. Will fail + with an assertion error if provided attributes are incorrect or the model name is missing. Call commit() to write changes to disk. """ pass @@ -161,8 +161,8 @@ class ModelManagerServiceBase(ABC): model_type: ModelType, ): """ - Delete the named model from configuration. If delete_files is true, - then the underlying weight file or diffusers directory will be deleted + Delete the named model from configuration. If delete_files is true, + then the underlying weight file or diffusers directory will be deleted as well. Call commit() to write to disk. """ pass @@ -249,7 +249,7 @@ class ModelManagerService(ModelManagerServiceBase): logger: ModuleType, ): """ - Initialize with the path to the models.yaml config file. + Initialize with the path to the models.yaml config file. Optional parameters are the torch device type, precision, max_models, and sequential_offload boolean. Note that the default device type and precision are set up for a CUDA system running at half precision. @@ -279,6 +279,8 @@ class ModelManagerService(ModelManagerServiceBase): if hasattr(config,'max_cache_size') \ else config.max_loaded_models * 2.5 + logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB") + sequential_offload = config.sequential_guidance self.mgr = ModelManager( @@ -334,7 +336,7 @@ class ModelManagerService(ModelManagerServiceBase): submodel=submodel, model_info=model_info ) - + return model_info def model_exists( @@ -394,8 +396,8 @@ class ModelManagerService(ModelManagerServiceBase): """ Update the named model with a dictionary of attributes. Will fail with an assertion error if the name already exists. Pass clobber=True to overwrite. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or + On a successful update, the config will be changed in memory. Will fail + with an assertion error if provided attributes are incorrect or the model name is missing. Call commit() to write changes to disk. """ self.logger.debug(f'add/update model {model_name}') @@ -427,8 +429,8 @@ class ModelManagerService(ModelManagerServiceBase): model_type: ModelType, ): """ - Delete the named model from configuration. If delete_files is true, - then the underlying weight file or diffusers directory will be deleted + Delete the named model from configuration. If delete_files is true, + then the underlying weight file or diffusers directory will be deleted as well. Call commit() to write to disk. """ self.logger.debug(f'delete model {model_name}') @@ -503,7 +505,7 @@ class ModelManagerService(ModelManagerServiceBase): @property def logger(self): return self.mgr.logger - + def heuristic_import(self, items_to_import: set[str], prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None, @@ -552,4 +554,3 @@ class ModelManagerService(ModelManagerServiceBase): interp = interp, force = force, ) - diff --git a/invokeai/backend/install/migrate_to_3.py b/invokeai/backend/install/migrate_to_3.py index c8e024f484..6f9cee6246 100644 --- a/invokeai/backend/install/migrate_to_3.py +++ b/invokeai/backend/install/migrate_to_3.py @@ -76,6 +76,10 @@ class MigrateTo3(object): Create a unique name for a model for use within models.yaml. ''' done = False + + # some model names have slashes in them, which really screws things up + name = name.replace('/','_') + key = ModelManager.create_key(name,info.base_type,info.model_type) unique_name = key counter = 1 @@ -219,11 +223,12 @@ class MigrateTo3(object): repo_id = 'openai/clip-vit-large-patch14' self._migrate_pretrained(CLIPTokenizer, repo_id= repo_id, - dest= target_dir / 'clip-vit-large-patch14' / 'tokenizer', + dest= target_dir / 'clip-vit-large-patch14', **kwargs) self._migrate_pretrained(CLIPTextModel, repo_id = repo_id, - dest = target_dir / 'clip-vit-large-patch14' / 'text_encoder', + dest = target_dir / 'clip-vit-large-patch14', + force = True, **kwargs) # sd-2 @@ -287,21 +292,21 @@ class MigrateTo3(object): def _model_probe_to_path(self, info: ModelProbeInfo)->Path: return Path(self.dest_models, info.base_type.value, info.model_type.value) - def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, **kwargs): - if dest.exists(): + def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force:bool=False, **kwargs): + if dest.exists() and not force: logger.info(f'Skipping existing {dest}') return model = model_class.from_pretrained(repo_id, **kwargs) - self._save_pretrained(model, dest) + self._save_pretrained(model, dest, overwrite=force) - def _save_pretrained(self, model, dest: Path): - if dest.exists(): - logger.info(f'Skipping existing {dest}') - return + def _save_pretrained(self, model, dest: Path, overwrite: bool=False): model_name = dest.name - download_path = dest.with_name(f'{model_name}.downloading') - model.save_pretrained(download_path, safe_serialization=True) - download_path.replace(dest) + if overwrite: + model.save_pretrained(dest, safe_serialization=True) + else: + download_path = dest.with_name(f'{model_name}.downloading') + model.save_pretrained(download_path, safe_serialization=True) + download_path.replace(dest) def _download_vae(self, repo_id: str, subfolder:str=None)->Path: vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / 'models/hub', subfolder=subfolder) @@ -569,8 +574,10 @@ script, which will perform a full upgrade in place.""" dest_directory = args.dest_directory assert dest_directory.is_dir(), f"{dest_directory} is not a valid directory" - assert (dest_directory / 'models').is_dir(), f"{dest_directory} does not contain a 'models' subdirectory" - assert (dest_directory / 'invokeai.yaml').exists(), f"{dest_directory} does not contain an InvokeAI init file." + + # TODO: revisit + # assert (dest_directory / 'models').is_dir(), f"{dest_directory} does not contain a 'models' subdirectory" + # assert (dest_directory / 'invokeai.yaml').exists(), f"{dest_directory} does not contain an InvokeAI init file." do_migrate(root_directory,dest_directory) diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py index 8bf716aeb7..29d61dee35 100644 --- a/invokeai/backend/install/model_install_backend.py +++ b/invokeai/backend/install/model_install_backend.py @@ -236,7 +236,6 @@ class ModelInstall(object): ) def _install_url(self, url: str)->AddModelResult: - # copy to a staging area, probe, import and delete with TemporaryDirectory(dir=self.config.models_path) as staging: location = download_with_resume(url,Path(staging)) if not location: diff --git a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py index 1eeee92fb7..e3e64940de 100644 --- a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py +++ b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py @@ -29,7 +29,7 @@ import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig from .model_manager import ModelManager -from .model_cache import ModelCache +from picklescan.scanner import scan_file_path from .models import BaseModelType, ModelVariantType try: @@ -1014,7 +1014,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt( checkpoint = load_file(checkpoint_path) else: if scan_needed: - ModelCache.scan_model(checkpoint_path, checkpoint_path) + # scan model + scan_result = scan_file_path(checkpoint_path) + if scan_result.infected_files != 0: + raise "The model {checkpoint_path} is potentially infected by malware. Aborting import." checkpoint = torch.load(checkpoint_path) # sometimes there is a state_dict key and sometimes not diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index b92020189d..57868ca197 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -1,16 +1,17 @@ from __future__ import annotations import copy -from pathlib import Path from contextlib import contextmanager from typing import Optional, Dict, Tuple, Any, Union, List -import torch -from safetensors.torch import load_file +from pathlib import Path +import torch +from compel.embeddings_provider import BaseTextualInversionManager +from diffusers.models import UNet2DConditionModel +from safetensors.torch import load_file from diffusers.models import UNet2DConditionModel from transformers import CLIPTextModel, CLIPTokenizer - -from compel.embeddings_provider import BaseTextualInversionManager +from torch.utils.hooks import RemovableHandle class LoRALayerBase: #rank: Optional[int] @@ -537,9 +538,10 @@ class ModelPatcher: original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True) # enable autocast to calc fp16 loras on cpu - with torch.autocast(device_type="cpu"): - layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 - layer_weight = layer.get_weight() * lora_weight * layer_scale + #with torch.autocast(device_type="cpu"): + layer.to(dtype=torch.float32) + layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 + layer_weight = layer.get_weight() * lora_weight * layer_scale if module.weight.shape != layer_weight.shape: # TODO: debug on lycoris @@ -653,6 +655,9 @@ class TextualInversionModel: else: result.embedding = next(iter(state_dict.values())) + if len(result.embedding.shape) == 1: + result.embedding = result.embedding.unsqueeze(0) + if not isinstance(result.embedding, torch.Tensor): raise ValueError(f"Invalid embeddings file: {file_path.name}") diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py index 77b6ac5115..df5a2f9272 100644 --- a/invokeai/backend/model_management/model_cache.py +++ b/invokeai/backend/model_management/model_cache.py @@ -100,8 +100,6 @@ class ModelCache(object): :param sha_chunksize: Chunksize to use when calculating sha256 model hash ''' #max_cache_size = 9999 - execution_device = torch.device('cuda') - self.model_infos: Dict[str, ModelBase] = dict() self.lazy_offloading = lazy_offloading #self.sequential_offload: bool=sequential_offload diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index e279a0fed2..25081f83b4 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -249,7 +249,7 @@ from .model_cache import ModelCache, ModelLocker from .models import ( BaseModelType, ModelType, SubModelType, ModelError, SchedulerPredictionType, MODEL_CLASSES, - ModelConfigBase, + ModelConfigBase, ModelNotFoundException, ) # We are only starting to number the config file with release 3. @@ -409,7 +409,7 @@ class ModelManager(object): if model_key not in self.models: self.scan_models_directory(base_model=base_model, model_type=model_type) if model_key not in self.models: - raise Exception(f"Model not found - {model_key}") + raise ModelNotFoundException(f"Model not found - {model_key}") model_config = self.models[model_key] model_path = self.app_config.root_path / model_config.path @@ -421,7 +421,7 @@ class ModelManager(object): else: self.models.pop(model_key, None) - raise Exception(f"Model not found - {model_key}") + raise ModelNotFoundException(f"Model not found - {model_key}") # vae/movq override # TODO: @@ -798,12 +798,12 @@ class ModelManager(object): if model_path.is_relative_to(self.app_config.root_path): model_path = model_path.relative_to(self.app_config.root_path) - try: - model_config: ModelConfigBase = model_class.probe_config(str(model_path)) - self.models[model_key] = model_config - new_models_found = True - except NotImplementedError as e: - self.logger.warning(e) + try: + model_config: ModelConfigBase = model_class.probe_config(str(model_path)) + self.models[model_key] = model_config + new_models_found = True + except NotImplementedError as e: + self.logger.warning(e) imported_models = self.autoimport() diff --git a/invokeai/backend/model_management/models/__init__.py b/invokeai/backend/model_management/models/__init__.py index 87b0ad3c4e..00630eef62 100644 --- a/invokeai/backend/model_management/models/__init__.py +++ b/invokeai/backend/model_management/models/__init__.py @@ -2,7 +2,7 @@ import inspect from enum import Enum from pydantic import BaseModel from typing import Literal, get_origin -from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings +from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model from .vae import VaeModel from .lora import LoRAModel diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py index afa62b2e4f..57c02bce76 100644 --- a/invokeai/backend/model_management/models/base.py +++ b/invokeai/backend/model_management/models/base.py @@ -15,6 +15,9 @@ from contextlib import suppress from pydantic import BaseModel, Field from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union +class ModelNotFoundException(Exception): + pass + class BaseModelType(str, Enum): StableDiffusion1 = "sd-1" StableDiffusion2 = "sd-2" diff --git a/invokeai/backend/model_management/models/textual_inversion.py b/invokeai/backend/model_management/models/textual_inversion.py index 9a032218f0..4dcdbb24ba 100644 --- a/invokeai/backend/model_management/models/textual_inversion.py +++ b/invokeai/backend/model_management/models/textual_inversion.py @@ -8,6 +8,7 @@ from .base import ( ModelType, SubModelType, classproperty, + ModelNotFoundException, ) # TODO: naming from ..lora import TextualInversionModel as TextualInversionModelRaw @@ -37,8 +38,15 @@ class TextualInversionModel(ModelBase): if child_type is not None: raise Exception("There is no child models in textual inversion") + checkpoint_path = self.model_path + if os.path.isdir(checkpoint_path): + checkpoint_path = os.path.join(checkpoint_path, "learned_embeds.bin") + + if not os.path.exists(checkpoint_path): + raise ModelNotFoundException() + model = TextualInversionModelRaw.from_checkpoint( - file_path=self.model_path, + file_path=checkpoint_path, dtype=torch_dtype, ) diff --git a/invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx b/invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx index 5b6142d748..bf66c0ee08 100644 --- a/invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx +++ b/invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx @@ -1,4 +1,8 @@ import { Box, ChakraProps, Flex, Heading, Image } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { memo } from 'react'; import { TypesafeDraggableData } from './typesafeDnd'; @@ -28,7 +32,24 @@ const STYLES: ChakraProps['sx'] = { }, }; +const selector = createSelector( + stateSelector, + (state) => { + const gallerySelectionCount = state.gallery.selection.length; + const batchSelectionCount = state.batch.selection.length; + + return { + gallerySelectionCount, + batchSelectionCount, + }; + }, + defaultSelectorOptions +); + const DragPreview = (props: OverlayDragImageProps) => { + const { gallerySelectionCount, batchSelectionCount } = + useAppSelector(selector); + if (!props.dragData) { return; } @@ -57,7 +78,7 @@ const DragPreview = (props: OverlayDragImageProps) => { ); } - if (props.dragData.payloadType === 'IMAGE_NAMES') { + if (props.dragData.payloadType === 'BATCH_SELECTION') { return ( { ...STYLES, }} > - {props.dragData.payload.imageNames.length} + {batchSelectionCount} + Images + + ); + } + + if (props.dragData.payloadType === 'GALLERY_SELECTION') { + return ( + + {gallerySelectionCount} Images ); diff --git a/invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx b/invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx index e744a70750..1478ace748 100644 --- a/invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx +++ b/invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx @@ -77,14 +77,18 @@ export type ImageDraggableData = BaseDragData & { payload: { imageDTO: ImageDTO }; }; -export type ImageNamesDraggableData = BaseDragData & { - payloadType: 'IMAGE_NAMES'; - payload: { imageNames: string[] }; +export type GallerySelectionDraggableData = BaseDragData & { + payloadType: 'GALLERY_SELECTION'; +}; + +export type BatchSelectionDraggableData = BaseDragData & { + payloadType: 'BATCH_SELECTION'; }; export type TypesafeDraggableData = | ImageDraggableData - | ImageNamesDraggableData; + | GallerySelectionDraggableData + | BatchSelectionDraggableData; interface UseDroppableTypesafeArguments extends Omit { @@ -155,11 +159,13 @@ export const isValidDrop = ( case 'SET_NODES_IMAGE': return payloadType === 'IMAGE_DTO'; case 'SET_MULTI_NODES_IMAGE': - return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES'; + return payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION'; case 'ADD_TO_BATCH': - return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES'; + return payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION'; case 'MOVE_BOARD': - return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES'; + return ( + payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION' || 'BATCH_SELECTION' + ); default: return false; } diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts index cb18d48301..ac1b9c5205 100644 --- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts +++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts @@ -20,10 +20,8 @@ const serializationDenylist: { nodes: nodesPersistDenylist, postprocessing: postprocessingPersistDenylist, system: systemPersistDenylist, - // config: configPersistDenyList, ui: uiPersistDenylist, controlNet: controlNetDenylist, - // hotkeys: hotkeysPersistDenylist, }; export const serialize: SerializeFunction = (data, key) => { diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts index ca20170c5d..f083a716a4 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts @@ -1,21 +1,21 @@ -import { startAppListening } from '..'; -import { imageDeleted } from 'services/api/thunks/image'; import { log } from 'app/logging/useLogger'; -import { clamp } from 'lodash-es'; -import { - imageSelected, - imageRemoved, - selectImagesIds, -} from 'features/gallery/store/gallerySlice'; import { resetCanvas } from 'features/canvas/store/canvasSlice'; import { controlNetReset } from 'features/controlNet/store/controlNetSlice'; -import { clearInitialImage } from 'features/parameters/store/generationSlice'; -import { nodeEditorReset } from 'features/nodes/store/nodesSlice'; -import { api } from 'services/api'; +import { + imageRemoved, + imageSelected, + selectFilteredImages, +} from 'features/gallery/store/gallerySlice'; import { imageDeletionConfirmed, isModalOpenChanged, } from 'features/imageDeletion/store/imageDeletionSlice'; +import { nodeEditorReset } from 'features/nodes/store/nodesSlice'; +import { clearInitialImage } from 'features/parameters/store/generationSlice'; +import { clamp } from 'lodash-es'; +import { api } from 'services/api'; +import { imageDeleted } from 'services/api/thunks/image'; +import { startAppListening } from '..'; const moduleLog = log.child({ namespace: 'image' }); @@ -37,7 +37,9 @@ export const addRequestedImageDeletionListener = () => { state.gallery.selection[state.gallery.selection.length - 1]; if (lastSelectedImage === image_name) { - const ids = selectImagesIds(state); + const filteredImages = selectFilteredImages(state); + + const ids = filteredImages.map((i) => i.image_name); const deletedImageIndex = ids.findIndex( (result) => result.toString() === image_name diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts index 56f660a653..24a5bffec7 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts @@ -1,24 +1,23 @@ import { createAction } from '@reduxjs/toolkit'; -import { startAppListening } from '../'; -import { log } from 'app/logging/useLogger'; import { TypesafeDraggableData, TypesafeDroppableData, } from 'app/components/ImageDnd/typesafeDnd'; -import { imageSelected } from 'features/gallery/store/gallerySlice'; -import { initialImageChanged } from 'features/parameters/store/generationSlice'; +import { log } from 'app/logging/useLogger'; import { imageAddedToBatch, imagesAddedToBatch, } from 'features/batch/store/batchSlice'; -import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; +import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice'; +import { imageSelected } from 'features/gallery/store/gallerySlice'; import { fieldValueChanged, imageCollectionFieldValueChanged, } from 'features/nodes/store/nodesSlice'; -import { boardsApi } from 'services/api/endpoints/boards'; +import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { boardImagesApi } from 'services/api/endpoints/boardImages'; +import { startAppListening } from '../'; const moduleLog = log.child({ namespace: 'dnd' }); @@ -33,6 +32,7 @@ export const addImageDroppedListener = () => { effect: (action, { dispatch, getState }) => { const { activeData, overData } = action.payload; const { actionType } = overData; + const state = getState(); // set current image if ( @@ -64,9 +64,9 @@ export const addImageDroppedListener = () => { // add multiple images to batch if ( actionType === 'ADD_TO_BATCH' && - activeData.payloadType === 'IMAGE_NAMES' + activeData.payloadType === 'GALLERY_SELECTION' ) { - dispatch(imagesAddedToBatch(activeData.payload.imageNames)); + dispatch(imagesAddedToBatch(state.gallery.selection)); } // set control image @@ -128,14 +128,14 @@ export const addImageDroppedListener = () => { // set multiple nodes images (multiple images handler) if ( actionType === 'SET_MULTI_NODES_IMAGE' && - activeData.payloadType === 'IMAGE_NAMES' + activeData.payloadType === 'GALLERY_SELECTION' ) { const { fieldName, nodeId } = overData.context; dispatch( imageCollectionFieldValueChanged({ nodeId, fieldName, - value: activeData.payload.imageNames.map((image_name) => ({ + value: state.gallery.selection.map((image_name) => ({ image_name, })), }) diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index 2fd071bd23..5208933e7b 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -8,31 +8,32 @@ import { import dynamicMiddlewares from 'redux-dynamic-middlewares'; import { rememberEnhancer, rememberReducer } from 'redux-remember'; +import batchReducer from 'features/batch/store/batchSlice'; import canvasReducer from 'features/canvas/store/canvasSlice'; import controlNetReducer from 'features/controlNet/store/controlNetSlice'; +import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice'; +import boardsReducer from 'features/gallery/store/boardSlice'; import galleryReducer from 'features/gallery/store/gallerySlice'; +import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice'; import lightboxReducer from 'features/lightbox/store/lightboxSlice'; +import loraReducer from 'features/lora/store/loraSlice'; +import nodesReducer from 'features/nodes/store/nodesSlice'; import generationReducer from 'features/parameters/store/generationSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; -import systemReducer from 'features/system/store/systemSlice'; -import nodesReducer from 'features/nodes/store/nodesSlice'; -import boardsReducer from 'features/gallery/store/boardSlice'; import configReducer from 'features/system/store/configSlice'; +import systemReducer from 'features/system/store/systemSlice'; import hotkeysReducer from 'features/ui/store/hotkeysSlice'; import uiReducer from 'features/ui/store/uiSlice'; -import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice'; -import batchReducer from 'features/batch/store/batchSlice'; -import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice'; import { listenerMiddleware } from './middleware/listenerMiddleware'; -import { actionSanitizer } from './middleware/devtools/actionSanitizer'; -import { actionsDenylist } from './middleware/devtools/actionsDenylist'; -import { stateSanitizer } from './middleware/devtools/stateSanitizer'; +import { api } from 'services/api'; import { LOCALSTORAGE_PREFIX } from './constants'; import { serialize } from './enhancers/reduxRemember/serialize'; import { unserialize } from './enhancers/reduxRemember/unserialize'; -import { api } from 'services/api'; +import { actionSanitizer } from './middleware/devtools/actionSanitizer'; +import { actionsDenylist } from './middleware/devtools/actionsDenylist'; +import { stateSanitizer } from './middleware/devtools/stateSanitizer'; const allReducers = { canvas: canvasReducer, @@ -50,6 +51,7 @@ const allReducers = { dynamicPrompts: dynamicPromptsReducer, batch: batchReducer, imageDeletion: imageDeletionReducer, + lora: loraReducer, [api.reducerPath]: api.reducer, }; @@ -69,6 +71,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [ 'controlNet', 'dynamicPrompts', 'batch', + 'lora', // 'boards', // 'hotkeys', // 'config', diff --git a/invokeai/frontend/web/src/common/components/IAICollapse.tsx b/invokeai/frontend/web/src/common/components/IAICollapse.tsx index 5db26f3841..09dc1392e2 100644 --- a/invokeai/frontend/web/src/common/components/IAICollapse.tsx +++ b/invokeai/frontend/web/src/common/components/IAICollapse.tsx @@ -4,22 +4,25 @@ import { Collapse, Flex, Spacer, - Switch, + Text, useColorMode, + useDisclosure, } from '@chakra-ui/react'; +import { AnimatePresence, motion } from 'framer-motion'; import { PropsWithChildren, memo } from 'react'; import { mode } from 'theme/util/mode'; export type IAIToggleCollapseProps = PropsWithChildren & { label: string; - isOpen: boolean; - onToggle: () => void; - withSwitch?: boolean; + activeLabel?: string; + defaultIsOpen?: boolean; }; const IAICollapse = (props: IAIToggleCollapseProps) => { - const { label, isOpen, onToggle, children, withSwitch = false } = props; + const { label, activeLabel, children, defaultIsOpen = false } = props; + const { isOpen, onToggle } = useDisclosure({ defaultIsOpen }); const { colorMode } = useColorMode(); + return ( { alignItems: 'center', p: 2, px: 4, + gap: 2, borderTopRadius: 'base', borderBottomRadius: isOpen ? 0 : 'base', bg: isOpen @@ -48,19 +52,40 @@ const IAICollapse = (props: IAIToggleCollapseProps) => { }} > {label} + + {activeLabel && ( + + + {activeLabel} + + + )} + - {withSwitch && } - {!withSwitch && ( - - )} + { '&:focus-within': { borderColor: mode(accent200, accent600)(colorMode), }, - '&:disabled': { + '&[data-disabled]': { backgroundColor: mode(base300, base700)(colorMode), color: mode(base600, base400)(colorMode), }, diff --git a/invokeai/frontend/web/src/common/components/IAIMantineSelect.tsx b/invokeai/frontend/web/src/common/components/IAIMantineSelect.tsx index 9b023fd2d7..585dc106a8 100644 --- a/invokeai/frontend/web/src/common/components/IAIMantineSelect.tsx +++ b/invokeai/frontend/web/src/common/components/IAIMantineSelect.tsx @@ -64,7 +64,7 @@ const IAIMantineSelect = (props: IAISelectProps) => { '&:focus-within': { borderColor: mode(accent200, accent600)(colorMode), }, - '&:disabled': { + '&[data-disabled]': { backgroundColor: mode(base300, base700)(colorMode), color: mode(base600, base400)(colorMode), }, diff --git a/invokeai/frontend/web/src/common/components/IAISwitch.tsx b/invokeai/frontend/web/src/common/components/IAISwitch.tsx index 54a3b30a4f..d25ab0d87e 100644 --- a/invokeai/frontend/web/src/common/components/IAISwitch.tsx +++ b/invokeai/frontend/web/src/common/components/IAISwitch.tsx @@ -36,7 +36,6 @@ const IAISwitch = (props: Props) => { isDisabled={isDisabled} width={width} display="flex" - gap={4} alignItems="center" {...formControlProps} > @@ -47,6 +46,7 @@ const IAISwitch = (props: Props) => { sx={{ cursor: isDisabled ? 'not-allowed' : 'pointer', ...formLabelProps?.sx, + pe: 4, }} {...formLabelProps} > diff --git a/invokeai/frontend/web/src/features/batch/components/BatchImage.tsx b/invokeai/frontend/web/src/features/batch/components/BatchImage.tsx index 822b1cf183..4a6250f93a 100644 --- a/invokeai/frontend/web/src/features/batch/components/BatchImage.tsx +++ b/invokeai/frontend/web/src/features/batch/components/BatchImage.tsx @@ -1,28 +1,29 @@ import { Box, Icon, Skeleton } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd'; +import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { FaExclamationCircle } from 'react-icons/fa'; -import { useGetImageDTOQuery } from 'services/api/endpoints/images'; -import { MouseEvent, memo, useCallback, useMemo } from 'react'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIDndImage from 'common/components/IAIDndImage'; import { batchImageRangeEndSelected, batchImageSelected, batchImageSelectionToggled, imageRemovedFromBatch, } from 'features/batch/store/batchSlice'; -import IAIDndImage from 'common/components/IAIDndImage'; -import { createSelector } from '@reduxjs/toolkit'; -import { RootState, stateSelector } from 'app/store/store'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd'; +import { MouseEvent, memo, useCallback, useMemo } from 'react'; +import { FaExclamationCircle } from 'react-icons/fa'; +import { useGetImageDTOQuery } from 'services/api/endpoints/images'; -const isSelectedSelector = createSelector( - [stateSelector, (state: RootState, imageName: string) => imageName], - (state, imageName) => ({ - selection: state.batch.selection, - isSelected: state.batch.selection.includes(imageName), - }), - defaultSelectorOptions -); +const makeSelector = (image_name: string) => + createSelector( + [stateSelector], + (state) => ({ + selectionCount: state.batch.selection.length, + isSelected: state.batch.selection.includes(image_name), + }), + defaultSelectorOptions + ); type BatchImageProps = { imageName: string; @@ -37,10 +38,13 @@ const BatchImage = (props: BatchImageProps) => { } = useGetImageDTOQuery(props.imageName); const dispatch = useAppDispatch(); - const { isSelected, selection } = useAppSelector((state) => - isSelectedSelector(state, props.imageName) + const selector = useMemo( + () => makeSelector(props.imageName), + [props.imageName] ); + const { isSelected, selectionCount } = useAppSelector(selector); + const handleClickRemove = useCallback(() => { dispatch(imageRemovedFromBatch(props.imageName)); }, [dispatch, props.imageName]); @@ -59,13 +63,10 @@ const BatchImage = (props: BatchImageProps) => { ); const draggableData = useMemo(() => { - if (selection.length > 1) { + if (selectionCount > 1) { return { id: 'batch', - payloadType: 'IMAGE_NAMES', - payload: { - imageNames: selection, - }, + payloadType: 'BATCH_SELECTION', }; } @@ -76,7 +77,7 @@ const BatchImage = (props: BatchImageProps) => { payload: { imageDTO }, }; } - }, [imageDTO, selection]); + }, [imageDTO, selectionCount]); if (isError) { return ; diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx index df73f1141d..dde449a464 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx @@ -1,25 +1,22 @@ -import { memo, useCallback, useMemo, useState } from 'react'; -import { ImageDTO } from 'services/api/types'; -import { - ControlNetConfig, - controlNetImageChanged, - controlNetSelector, -} from '../store/controlNetSlice'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { Box, Flex, SystemStyleObject } from '@chakra-ui/react'; -import IAIDndImage from 'common/components/IAIDndImage'; import { createSelector } from '@reduxjs/toolkit'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { IAILoadingImageFallback } from 'common/components/IAIImageFallback'; -import IAIIconButton from 'common/components/IAIIconButton'; -import { FaUndo } from 'react-icons/fa'; -import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { skipToken } from '@reduxjs/toolkit/dist/query'; import { TypesafeDraggableData, TypesafeDroppableData, } from 'app/components/ImageDnd/typesafeDnd'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIDndImage from 'common/components/IAIDndImage'; +import { IAILoadingImageFallback } from 'common/components/IAIImageFallback'; +import { memo, useCallback, useMemo, useState } from 'react'; +import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { PostUploadAction } from 'services/api/thunks/image'; +import { + ControlNetConfig, + controlNetImageChanged, + controlNetSelector, +} from '../store/controlNetSlice'; const selector = createSelector( controlNetSelector, @@ -83,15 +80,14 @@ const ControlNetImagePreview = (props: Props) => { } }, [controlImage, controlNetId]); - const droppableData = useMemo(() => { - if (controlNetId) { - return { - id: controlNetId, - actionType: 'SET_CONTROLNET_IMAGE', - context: { controlNetId }, - }; - } - }, [controlNetId]); + const droppableData = useMemo( + () => ({ + id: controlNetId, + actionType: 'SET_CONTROLNET_IMAGE', + context: { controlNetId }, + }), + [controlNetId] + ); const postUploadAction = useMemo( () => ({ type: 'SET_CONTROLNET_IMAGE', controlNetId }), diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetFeatureToggle.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetFeatureToggle.tsx new file mode 100644 index 0000000000..3a7eea2fbf --- /dev/null +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetFeatureToggle.tsx @@ -0,0 +1,36 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAISwitch from 'common/components/IAISwitch'; +import { isControlNetEnabledToggled } from 'features/controlNet/store/controlNetSlice'; +import { useCallback } from 'react'; + +const selector = createSelector( + stateSelector, + (state) => { + const { isEnabled } = state.controlNet; + + return { isEnabled }; + }, + defaultSelectorOptions +); + +const ParamControlNetFeatureToggle = () => { + const { isEnabled } = useAppSelector(selector); + const dispatch = useAppDispatch(); + + const handleChange = useCallback(() => { + dispatch(isControlNetEnabledToggled()); + }, [dispatch]); + + return ( + + ); +}; + +export default ParamControlNetFeatureToggle; diff --git a/invokeai/frontend/web/src/features/controlNet/util/getValidControlNets.ts b/invokeai/frontend/web/src/features/controlNet/util/getValidControlNets.ts new file mode 100644 index 0000000000..4bff39db63 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlNet/util/getValidControlNets.ts @@ -0,0 +1,15 @@ +import { filter } from 'lodash-es'; +import { ControlNetConfig } from '../store/controlNetSlice'; + +export const getValidControlNets = ( + controlNets: Record +) => { + const validControlNets = filter( + controlNets, + (c) => + c.isEnabled && + (Boolean(c.processedControlImage) || + (c.processorType === 'none' && Boolean(c.controlImage))) + ); + return validControlNets; +}; diff --git a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCollapse.tsx b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCollapse.tsx index 1aefecf3e6..0e41fad994 100644 --- a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCollapse.tsx +++ b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCollapse.tsx @@ -1,40 +1,30 @@ +import { Flex } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAICollapse from 'common/components/IAICollapse'; -import { useCallback } from 'react'; -import { isEnabledToggled } from '../store/slice'; -import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts'; import ParamDynamicPromptsCombinatorial from './ParamDynamicPromptsCombinatorial'; -import { Flex } from '@chakra-ui/react'; +import ParamDynamicPromptsToggle from './ParamDynamicPromptsEnabled'; +import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts'; const selector = createSelector( stateSelector, (state) => { const { isEnabled } = state.dynamicPrompts; - return { isEnabled }; + return { activeLabel: isEnabled ? 'Enabled' : undefined }; }, defaultSelectorOptions ); const ParamDynamicPromptsCollapse = () => { - const dispatch = useAppDispatch(); - const { isEnabled } = useAppSelector(selector); - - const handleToggleIsEnabled = useCallback(() => { - dispatch(isEnabledToggled()); - }, [dispatch]); + const { activeLabel } = useAppSelector(selector); return ( - + + diff --git a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCombinatorial.tsx b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCombinatorial.tsx index 30c2240c37..cb930acd3b 100644 --- a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCombinatorial.tsx +++ b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCombinatorial.tsx @@ -1,23 +1,23 @@ -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { combinatorialToggled } from '../store/slice'; import { createSelector } from '@reduxjs/toolkit'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { useCallback } from 'react'; import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAISwitch from 'common/components/IAISwitch'; +import { useCallback } from 'react'; +import { combinatorialToggled } from '../store/slice'; const selector = createSelector( stateSelector, (state) => { - const { combinatorial } = state.dynamicPrompts; + const { combinatorial, isEnabled } = state.dynamicPrompts; - return { combinatorial }; + return { combinatorial, isDisabled: !isEnabled }; }, defaultSelectorOptions ); const ParamDynamicPromptsCombinatorial = () => { - const { combinatorial } = useAppSelector(selector); + const { combinatorial, isDisabled } = useAppSelector(selector); const dispatch = useAppDispatch(); const handleChange = useCallback(() => { @@ -26,6 +26,7 @@ const ParamDynamicPromptsCombinatorial = () => { return ( { + const { isEnabled } = state.dynamicPrompts; + + return { isEnabled }; + }, + defaultSelectorOptions +); + +const ParamDynamicPromptsToggle = () => { + const dispatch = useAppDispatch(); + const { isEnabled } = useAppSelector(selector); + + const handleToggleIsEnabled = useCallback(() => { + dispatch(isEnabledToggled()); + }, [dispatch]); + + return ( + + ); +}; + +export default ParamDynamicPromptsToggle; diff --git a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx index 19f02ae3e5..172120fd1e 100644 --- a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx +++ b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx @@ -1,25 +1,31 @@ -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import IAISlider from 'common/components/IAISlider'; -import { maxPromptsChanged, maxPromptsReset } from '../store/slice'; import { createSelector } from '@reduxjs/toolkit'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { useCallback } from 'react'; import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAISlider from 'common/components/IAISlider'; +import { useCallback } from 'react'; +import { maxPromptsChanged, maxPromptsReset } from '../store/slice'; const selector = createSelector( stateSelector, (state) => { - const { maxPrompts, combinatorial } = state.dynamicPrompts; + const { maxPrompts, combinatorial, isEnabled } = state.dynamicPrompts; const { min, sliderMax, inputMax } = state.config.sd.dynamicPrompts.maxPrompts; - return { maxPrompts, min, sliderMax, inputMax, combinatorial }; + return { + maxPrompts, + min, + sliderMax, + inputMax, + isDisabled: !isEnabled || !combinatorial, + }; }, defaultSelectorOptions ); const ParamDynamicPromptsMaxPrompts = () => { - const { maxPrompts, min, sliderMax, inputMax, combinatorial } = + const { maxPrompts, min, sliderMax, inputMax, isDisabled } = useAppSelector(selector); const dispatch = useAppDispatch(); @@ -37,7 +43,7 @@ const ParamDynamicPromptsMaxPrompts = () => { return ( image_name], - ({ gallery }, image_name) => { - const isSelected = gallery.selection.includes(image_name); - const selection = gallery.selection; - return { - isSelected, - selection, - }; - }, - defaultSelectorOptions -); +export const makeSelector = (image_name: string) => + createSelector( + [stateSelector], + ({ gallery }) => { + const isSelected = gallery.selection.includes(image_name); + const selectionCount = gallery.selection.length; + return { + isSelected, + selectionCount, + }; + }, + defaultSelectorOptions + ); interface HoverableImageProps { imageDTO: ImageDTO; @@ -38,13 +39,13 @@ interface HoverableImageProps { * Gallery image component with delete/use all/use seed buttons on hover. */ const GalleryImage = (props: HoverableImageProps) => { - const { isSelected, selection } = useAppSelector((state) => - selector(state, props.imageDTO) - ); - const { imageDTO } = props; const { image_url, thumbnail_url, image_name } = imageDTO; + const localSelector = useMemo(() => makeSelector(image_name), [image_name]); + + const { isSelected, selectionCount } = useAppSelector(localSelector); + const dispatch = useAppDispatch(); const { t } = useTranslation(); @@ -74,11 +75,10 @@ const GalleryImage = (props: HoverableImageProps) => { ); const draggableData = useMemo(() => { - if (selection.length > 1) { + if (selectionCount > 1) { return { id: 'gallery-image', - payloadType: 'IMAGE_NAMES', - payload: { imageNames: selection }, + payloadType: 'GALLERY_SELECTION', }; } @@ -89,7 +89,7 @@ const GalleryImage = (props: HoverableImageProps) => { payload: { imageDTO }, }; } - }, [imageDTO, selection]); + }, [imageDTO, selectionCount]); return ( diff --git a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts index f4d2babf38..41a52e3452 100644 --- a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts @@ -7,7 +7,6 @@ import { import { RootState } from 'app/store/store'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { dateComparator } from 'common/util/dateComparator'; -import { imageDeletionConfirmed } from 'features/imageDeletion/store/imageDeletionSlice'; import { keyBy, uniq } from 'lodash-es'; import { boardsApi } from 'services/api/endpoints/boards'; import { @@ -174,11 +173,6 @@ export const gallerySlice = createSlice({ state.limit = limit; state.total = total; }); - builder.addCase(imageDeletionConfirmed, (state, action) => { - // Image deleted - const { image_name } = action.payload.imageDTO; - imagesAdapter.removeOne(state, image_name); - }); builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { const { image_name, image_url, thumbnail_url } = action.payload; diff --git a/invokeai/frontend/web/src/features/imageDeletion/components/DeleteImageModal.tsx b/invokeai/frontend/web/src/features/imageDeletion/components/DeleteImageModal.tsx index cdc8257488..8306437cc7 100644 --- a/invokeai/frontend/web/src/features/imageDeletion/components/DeleteImageModal.tsx +++ b/invokeai/frontend/web/src/features/imageDeletion/components/DeleteImageModal.tsx @@ -23,6 +23,7 @@ import { stateSelector } from 'app/store/store'; import { imageDeletionConfirmed, imageToDeleteCleared, + isModalOpenChanged, selectImageUsage, } from '../store/imageDeletionSlice'; @@ -63,6 +64,7 @@ const DeleteImageModal = () => { const handleClose = useCallback(() => { dispatch(imageToDeleteCleared()); + dispatch(isModalOpenChanged(false)); }, [dispatch]); const handleDelete = useCallback(() => { diff --git a/invokeai/frontend/web/src/features/imageDeletion/store/imageDeletionSlice.ts b/invokeai/frontend/web/src/features/imageDeletion/store/imageDeletionSlice.ts index 0daffba0d7..49630bcdb4 100644 --- a/invokeai/frontend/web/src/features/imageDeletion/store/imageDeletionSlice.ts +++ b/invokeai/frontend/web/src/features/imageDeletion/store/imageDeletionSlice.ts @@ -31,6 +31,7 @@ const imageDeletion = createSlice({ }, imageToDeleteCleared: (state) => { state.imageToDelete = null; + state.isModalOpen = false; }, }, }); diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx new file mode 100644 index 0000000000..23459e9410 --- /dev/null +++ b/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx @@ -0,0 +1,59 @@ +import { Flex } from '@chakra-ui/react'; +import { useAppDispatch } from 'app/store/storeHooks'; +import IAIIconButton from 'common/components/IAIIconButton'; +import IAISlider from 'common/components/IAISlider'; +import { memo, useCallback } from 'react'; +import { FaTrash } from 'react-icons/fa'; +import { Lora, loraRemoved, loraWeightChanged } from '../store/loraSlice'; + +type Props = { + lora: Lora; +}; + +const ParamLora = (props: Props) => { + const dispatch = useAppDispatch(); + const { lora } = props; + + const handleChange = useCallback( + (v: number) => { + dispatch(loraWeightChanged({ id: lora.id, weight: v })); + }, + [dispatch, lora.id] + ); + + const handleReset = useCallback(() => { + dispatch(loraWeightChanged({ id: lora.id, weight: 1 })); + }, [dispatch, lora.id]); + + const handleRemoveLora = useCallback(() => { + dispatch(loraRemoved(lora.id)); + }, [dispatch, lora.id]); + + return ( + + + } + colorScheme="error" + /> + + ); +}; + +export default memo(ParamLora); diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx new file mode 100644 index 0000000000..6e69f036df --- /dev/null +++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx @@ -0,0 +1,36 @@ +import { Flex } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAICollapse from 'common/components/IAICollapse'; +import { size } from 'lodash-es'; +import { memo } from 'react'; +import ParamLoraList from './ParamLoraList'; +import ParamLoraSelect from './ParamLoraSelect'; + +const selector = createSelector( + stateSelector, + (state) => { + const loraCount = size(state.lora.loras); + return { + activeLabel: loraCount > 0 ? `${loraCount} Active` : undefined, + }; + }, + defaultSelectorOptions +); + +const ParamLoraCollapse = () => { + const { activeLabel } = useAppSelector(selector); + + return ( + + + + + + + ); +}; + +export default memo(ParamLoraCollapse); diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx new file mode 100644 index 0000000000..89432ac862 --- /dev/null +++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx @@ -0,0 +1,24 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import { map } from 'lodash-es'; +import ParamLora from './ParamLora'; + +const selector = createSelector( + stateSelector, + ({ lora }) => { + const { loras } = lora; + + return { loras }; + }, + defaultSelectorOptions +); + +const ParamLoraList = () => { + const { loras } = useAppSelector(selector); + + return map(loras, (lora) => ); +}; + +export default ParamLoraList; diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx new file mode 100644 index 0000000000..54ac3d615d --- /dev/null +++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx @@ -0,0 +1,107 @@ +import { Text } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; +import { forEach } from 'lodash-es'; +import { forwardRef, useCallback, useMemo } from 'react'; +import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; +import { loraAdded } from '../store/loraSlice'; + +type LoraSelectItem = { + label: string; + value: string; + description?: string; +}; + +const selector = createSelector( + stateSelector, + ({ lora }) => ({ + loras: lora.loras, + }), + defaultSelectorOptions +); + +const ParamLoraSelect = () => { + const dispatch = useAppDispatch(); + const { loras } = useAppSelector(selector); + const { data: lorasQueryData } = useGetLoRAModelsQuery(); + + const data = useMemo(() => { + if (!lorasQueryData) { + return []; + } + + const data: LoraSelectItem[] = []; + + forEach(lorasQueryData.entities, (lora, id) => { + if (!lora || Boolean(id in loras)) { + return; + } + + data.push({ + value: id, + label: lora.name, + description: lora.description, + }); + }); + + return data; + }, [loras, lorasQueryData]); + + const handleChange = useCallback( + (v: string[]) => { + const loraEntity = lorasQueryData?.entities[v[0]]; + if (!loraEntity) { + return; + } + v[0] && dispatch(loraAdded(loraEntity)); + }, + [dispatch, lorasQueryData?.entities] + ); + + return ( + + item.label.toLowerCase().includes(value.toLowerCase().trim()) || + item.value.toLowerCase().includes(value.toLowerCase().trim()) + } + onChange={handleChange} + /> + ); +}; + +interface ItemProps extends React.ComponentPropsWithoutRef<'div'> { + value: string; + label: string; + description?: string; +} + +const SelectItem = forwardRef( + ({ label, description, ...others }: ItemProps, ref) => { + return ( +
+
+ {label} + {description && ( + + {description} + + )} +
+
+ ); + } +); + +SelectItem.displayName = 'SelectItem'; + +export default ParamLoraSelect; diff --git a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts new file mode 100644 index 0000000000..c9b290eb2d --- /dev/null +++ b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts @@ -0,0 +1,46 @@ +import { PayloadAction, createSlice } from '@reduxjs/toolkit'; +import { LoRAModelConfigEntity } from 'services/api/endpoints/models'; + +export type Lora = { + id: string; + name: string; + weight: number; +}; + +export const defaultLoRAConfig: Omit = { + weight: 1, +}; + +export type LoraState = { + loras: Record; +}; + +export const intialLoraState: LoraState = { + loras: {}, +}; + +export const loraSlice = createSlice({ + name: 'lora', + initialState: intialLoraState, + reducers: { + loraAdded: (state, action: PayloadAction) => { + const { name, id } = action.payload; + state.loras[id] = { id, name, ...defaultLoRAConfig }; + }, + loraRemoved: (state, action: PayloadAction) => { + const id = action.payload; + delete state.loras[id]; + }, + loraWeightChanged: ( + state, + action: PayloadAction<{ id: string; weight: number }> + ) => { + const { id, weight } = action.payload; + state.loras[id].weight = weight; + }, + }, +}); + +export const { loraAdded, loraRemoved, loraWeightChanged } = loraSlice.actions; + +export default loraSlice.reducer; diff --git a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx index 062fec2fdc..9925a48381 100644 --- a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx @@ -12,6 +12,7 @@ import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFie import ImageInputFieldComponent from './fields/ImageInputFieldComponent'; import ItemInputFieldComponent from './fields/ItemInputFieldComponent'; import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent'; +import LoRAModelInputFieldComponent from './fields/LoRAModelInputFieldComponent'; import ModelInputFieldComponent from './fields/ModelInputFieldComponent'; import NumberInputFieldComponent from './fields/NumberInputFieldComponent'; import StringInputFieldComponent from './fields/StringInputFieldComponent'; @@ -163,6 +164,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => { ); } + if (type === 'lora_model' && template.type === 'lora_model') { + return ( + + ); + } + if (type === 'array' && template.type === 'array') { return ( @@ -34,23 +32,6 @@ const ImageInputFieldComponent = ( isSuccess, } = useGetImageDTOQuery(field.value?.image_name ?? skipToken); - const handleDrop = useCallback( - ({ image_name }: ImageDTO) => { - if (field.value?.image_name === image_name) { - return; - } - - dispatch( - fieldValueChanged({ - nodeId, - fieldName: field.name, - value: { image_name }, - }) - ); - }, - [dispatch, field.name, field.value, nodeId] - ); - const handleReset = useCallback(() => { dispatch( fieldValueChanged({ @@ -71,15 +52,14 @@ const ImageInputFieldComponent = ( } }, [field.name, imageDTO, nodeId]); - const droppableData = useMemo(() => { - if (imageDTO) { - return { - id: `node-${nodeId}-${field.name}`, - actionType: 'SET_NODES_IMAGE', - context: { nodeId, fieldName: field.name }, - }; - } - }, [field.name, imageDTO, nodeId]); + const droppableData = useMemo( + () => ({ + id: `node-${nodeId}-${field.name}`, + actionType: 'SET_NODES_IMAGE', + context: { nodeId, fieldName: field.name }, + }), + [field.name, nodeId] + ); const postUploadAction = useMemo( () => ({ diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx new file mode 100644 index 0000000000..02cdfd454d --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx @@ -0,0 +1,102 @@ +import { SelectItem } from '@mantine/core'; +import { useAppDispatch } from 'app/store/storeHooks'; +import IAIMantineSelect from 'common/components/IAIMantineSelect'; +import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; +import { + VaeModelInputFieldTemplate, + VaeModelInputFieldValue, +} from 'features/nodes/types/types'; +import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect'; +import { forEach, isString } from 'lodash-es'; +import { memo, useCallback, useEffect, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; +import { FieldComponentProps } from './types'; + +const LoRAModelInputFieldComponent = ( + props: FieldComponentProps< + VaeModelInputFieldValue, + VaeModelInputFieldTemplate + > +) => { + const { nodeId, field } = props; + + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + + const { data: loraModels } = useGetLoRAModelsQuery(); + + const selectedModel = useMemo( + () => loraModels?.entities[field.value ?? loraModels.ids[0]], + [loraModels?.entities, loraModels?.ids, field.value] + ); + + const data = useMemo(() => { + if (!loraModels) { + return []; + } + + const data: SelectItem[] = []; + + forEach(loraModels.entities, (model, id) => { + if (!model) { + return; + } + + data.push({ + value: id, + label: model.name, + group: BASE_MODEL_NAME_MAP[model.base_model], + }); + }); + + return data; + }, [loraModels]); + + const handleValueChanged = useCallback( + (v: string | null) => { + if (!v) { + return; + } + + dispatch( + fieldValueChanged({ + nodeId, + fieldName: field.name, + value: v, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + + useEffect(() => { + if (field.value && loraModels?.ids.includes(field.value)) { + return; + } + + const firstLora = loraModels?.ids[0]; + + if (!isString(firstLora)) { + return; + } + + handleValueChanged(firstLora); + }, [field.value, handleValueChanged, loraModels?.ids]); + + return ( + + ); +}; + +export default memo(LoRAModelInputFieldComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx index b5bb9c5b74..ee739e1002 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx @@ -11,7 +11,7 @@ import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/component import { forEach, isString } from 'lodash-es'; import { memo, useCallback, useEffect, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { useListModelsQuery } from 'services/api/endpoints/models'; +import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import { FieldComponentProps } from './types'; const ModelInputFieldComponent = ( @@ -22,9 +22,7 @@ const ModelInputFieldComponent = ( const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { data: mainModels } = useListModelsQuery({ - model_type: 'main', - }); + const { data: mainModels } = useGetMainModelsQuery(); const data = useMemo(() => { if (!mainModels) { diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/VaeModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/VaeModelInputFieldComponent.tsx index 74d9942c84..b4408e41b2 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/VaeModelInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/VaeModelInputFieldComponent.tsx @@ -10,7 +10,7 @@ import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/component import { forEach } from 'lodash-es'; import { memo, useCallback, useEffect, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { useListModelsQuery } from 'services/api/endpoints/models'; +import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; import { FieldComponentProps } from './types'; const VaeModelInputFieldComponent = ( @@ -24,9 +24,7 @@ const VaeModelInputFieldComponent = ( const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { data: vaeModels } = useListModelsQuery({ - model_type: 'vae', - }); + const { data: vaeModels } = useGetVaeModelsQuery(); const selectedModel = useMemo( () => vaeModels?.entities[field.value ?? vaeModels.ids[0]], diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index ffc93db2ba..4fa69c626b 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -1,5 +1,8 @@ import { createSlice, PayloadAction } from '@reduxjs/toolkit'; +import { RootState } from 'app/store/store'; +import { cloneDeep, uniqBy } from 'lodash-es'; import { OpenAPIV3 } from 'openapi-types'; +import { RgbaColor } from 'react-colorful'; import { addEdge, applyEdgeChanges, @@ -11,12 +14,9 @@ import { NodeChange, OnConnectStartParams, } from 'reactflow'; -import { ImageField } from 'services/api/types'; import { receivedOpenAPISchema } from 'services/api/thunks/schema'; +import { ImageField } from 'services/api/types'; import { InvocationTemplate, InvocationValue } from '../types/types'; -import { RgbaColor } from 'react-colorful'; -import { RootState } from 'app/store/store'; -import { cloneDeep, isArray, uniq, uniqBy } from 'lodash-es'; export type NodesState = { nodes: Node[]; diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index b864501803..5fe780a286 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -18,6 +18,7 @@ export const FIELD_TYPE_MAP: Record = { VaeField: 'vae', model: 'model', vae_model: 'vae_model', + lora_model: 'lora_model', array: 'array', item: 'item', ColorField: 'color', @@ -120,7 +121,13 @@ export const FIELDS: Record = { vae_model: { color: 'teal', colorCssVar: getColorTokenCssVariable('teal'), - title: 'Model', + title: 'VAE', + description: 'Models are models.', + }, + lora_model: { + color: 'teal', + colorCssVar: getColorTokenCssVariable('teal'), + title: 'LoRA', description: 'Models are models.', }, array: { diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index c7e573ace2..3de8cae9ff 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -65,6 +65,7 @@ export type FieldType = | 'control' | 'model' | 'vae_model' + | 'lora_model' | 'array' | 'item' | 'color' @@ -93,6 +94,7 @@ export type InputFieldValue = | EnumInputFieldValue | ModelInputFieldValue | VaeModelInputFieldValue + | LoRAModelInputFieldValue | ArrayInputFieldValue | ItemInputFieldValue | ColorInputFieldValue @@ -119,6 +121,7 @@ export type InputFieldTemplate = | EnumInputFieldTemplate | ModelInputFieldTemplate | VaeModelInputFieldTemplate + | LoRAModelInputFieldTemplate | ArrayInputFieldTemplate | ItemInputFieldTemplate | ColorInputFieldTemplate @@ -236,6 +239,11 @@ export type VaeModelInputFieldValue = FieldValueBase & { value?: string; }; +export type LoRAModelInputFieldValue = FieldValueBase & { + type: 'lora_model'; + value?: string; +}; + export type ArrayInputFieldValue = FieldValueBase & { type: 'array'; value?: (string | number)[]; @@ -350,6 +358,11 @@ export type VaeModelInputFieldTemplate = InputFieldTemplateBase & { type: 'vae_model'; }; +export type LoRAModelInputFieldTemplate = InputFieldTemplateBase & { + default: string; + type: 'lora_model'; +}; + export type ArrayInputFieldTemplate = InputFieldTemplateBase & { default: []; type: 'array'; diff --git a/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts index 11ceb23763..5c4d67ebd3 100644 --- a/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts @@ -1,5 +1,5 @@ import { RootState } from 'app/store/store'; -import { filter } from 'lodash-es'; +import { getValidControlNets } from 'features/controlNet/util/getValidControlNets'; import { CollectInvocation, ControlNetInvocation } from 'services/api/types'; import { NonNullableGraph } from '../types/types'; import { CONTROL_NET_COLLECT } from './graphBuilders/constants'; @@ -11,13 +11,7 @@ export const addControlNetToLinearGraph = ( ): void => { const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet; - const validControlNets = filter( - controlNets, - (c) => - c.isEnabled && - (Boolean(c.processedControlImage) || - (c.processorType === 'none' && Boolean(c.controlImage))) - ); + const validControlNets = getValidControlNets(controlNets); if (isControlNetEnabled && Boolean(validControlNets.length)) { if (validControlNets.length > 1) { diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts index c71618175a..1c2dbc0c3e 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts @@ -18,6 +18,7 @@ import { IntegerInputFieldTemplate, ItemInputFieldTemplate, LatentsInputFieldTemplate, + LoRAModelInputFieldTemplate, ModelInputFieldTemplate, OutputFieldTemplate, StringInputFieldTemplate, @@ -191,6 +192,21 @@ const buildVaeModelInputFieldTemplate = ({ return template; }; +const buildLoRAModelInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): LoRAModelInputFieldTemplate => { + const template: LoRAModelInputFieldTemplate = { + ...baseField, + type: 'lora_model', + inputRequirement: 'always', + inputKind: 'direct', + default: schemaObject.default ?? undefined, + }; + + return template; +}; + const buildImageInputFieldTemplate = ({ schemaObject, baseField, @@ -460,6 +476,9 @@ export const buildInputFieldTemplate = ( if (['vae_model'].includes(fieldType)) { return buildVaeModelInputFieldTemplate({ schemaObject, baseField }); } + if (['lora_model'].includes(fieldType)) { + return buildLoRAModelInputFieldTemplate({ schemaObject, baseField }); + } if (['enum'].includes(fieldType)) { return buildEnumInputFieldTemplate({ schemaObject, baseField }); } diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts index a94d3ddef2..950038b691 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts @@ -79,6 +79,10 @@ export const buildInputFieldValue = ( if (template.type === 'vae_model') { fieldValue.value = undefined; } + + if (template.type === 'lora_model') { + fieldValue.value = undefined; + } } return fieldValue; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts new file mode 100644 index 0000000000..9712ef4d5f --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts @@ -0,0 +1,148 @@ +import { RootState } from 'app/store/store'; +import { NonNullableGraph } from 'features/nodes/types/types'; +import { forEach, size } from 'lodash-es'; +import { LoraLoaderInvocation } from 'services/api/types'; +import { modelIdToLoRAModelField } from '../modelIdToLoRAName'; +import { + LORA_LOADER, + MAIN_MODEL_LOADER, + NEGATIVE_CONDITIONING, + POSITIVE_CONDITIONING, +} from './constants'; + +export const addLoRAsToGraph = ( + graph: NonNullableGraph, + state: RootState, + baseNodeId: string +): void => { + /** + * LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them. + * They then output the UNet and CLIP models references on to either the next LoRA in the chain, + * or to the inference/conditioning nodes. + * + * So we need to inject a LoRA chain into the graph. + */ + + const { loras } = state.lora; + const loraCount = size(loras); + + if (loraCount > 0) { + // remove any existing connections from main model loader, we need to insert the lora nodes + graph.edges = graph.edges.filter( + (e) => + !( + e.source.node_id === MAIN_MODEL_LOADER && + ['unet', 'clip'].includes(e.source.field) + ) + ); + } + + // we need to remember the last lora so we can chain from it + let lastLoraNodeId = ''; + let currentLoraIndex = 0; + + forEach(loras, (lora) => { + const { id, name, weight } = lora; + const loraField = modelIdToLoRAModelField(id); + const currentLoraNodeId = `${LORA_LOADER}_${loraField.model_name.replace( + '.', + '_' + )}`; + + const loraLoaderNode: LoraLoaderInvocation = { + type: 'lora_loader', + id: currentLoraNodeId, + lora: loraField, + weight, + }; + + graph.nodes[currentLoraNodeId] = loraLoaderNode; + + if (currentLoraIndex === 0) { + // first lora = start the lora chain, attach directly to model loader + graph.edges.push({ + source: { + node_id: MAIN_MODEL_LOADER, + field: 'unet', + }, + destination: { + node_id: currentLoraNodeId, + field: 'unet', + }, + }); + + graph.edges.push({ + source: { + node_id: MAIN_MODEL_LOADER, + field: 'clip', + }, + destination: { + node_id: currentLoraNodeId, + field: 'clip', + }, + }); + } else { + // we are in the middle of the lora chain, instead connect to the previous lora + graph.edges.push({ + source: { + node_id: lastLoraNodeId, + field: 'unet', + }, + destination: { + node_id: currentLoraNodeId, + field: 'unet', + }, + }); + graph.edges.push({ + source: { + node_id: lastLoraNodeId, + field: 'clip', + }, + destination: { + node_id: currentLoraNodeId, + field: 'clip', + }, + }); + } + + if (currentLoraIndex === loraCount - 1) { + // final lora, end the lora chain - we need to connect up to inference and conditioning nodes + graph.edges.push({ + source: { + node_id: currentLoraNodeId, + field: 'unet', + }, + destination: { + node_id: baseNodeId, + field: 'unet', + }, + }); + + graph.edges.push({ + source: { + node_id: currentLoraNodeId, + field: 'clip', + }, + destination: { + node_id: POSITIVE_CONDITIONING, + field: 'clip', + }, + }); + + graph.edges.push({ + source: { + node_id: currentLoraNodeId, + field: 'clip', + }, + destination: { + node_id: NEGATIVE_CONDITIONING, + field: 'clip', + }, + }); + } + + // increment the lora for the next one in the chain + lastLoraNodeId = currentLoraNodeId; + currentLoraIndex += 1; + }); +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts index 5cf9882ac1..1843efef84 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts @@ -9,6 +9,7 @@ import { import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { modelIdToMainModelField } from '../modelIdToMainModelField'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addVAEToGraph } from './addVAEToGraph'; import { IMAGE_TO_IMAGE_GRAPH, @@ -252,6 +253,8 @@ export const buildCanvasImageToImageGraph = ( }); } + addLoRAsToGraph(graph, state, LATENTS_TO_LATENTS); + // Add VAE addVAEToGraph(graph, state); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts index 82912de219..c4f9415067 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts @@ -8,6 +8,7 @@ import { RangeOfSizeInvocation, } from 'services/api/types'; import { modelIdToMainModelField } from '../modelIdToMainModelField'; +import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addVAEToGraph } from './addVAEToGraph'; import { INPAINT, @@ -194,6 +195,8 @@ export const buildCanvasInpaintGraph = ( ], }; + addLoRAsToGraph(graph, state, INPAINT); + // Add VAE addVAEToGraph(graph, state); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts index cfe5e62805..976ea4fd01 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts @@ -3,6 +3,7 @@ import { NonNullableGraph } from 'features/nodes/types/types'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { modelIdToMainModelField } from '../modelIdToMainModelField'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addVAEToGraph } from './addVAEToGraph'; import { LATENTS_TO_IMAGE, @@ -157,6 +158,8 @@ export const buildCanvasTextToImageGraph = ( ], }; + addLoRAsToGraph(graph, state, TEXT_TO_LATENTS); + // Add VAE addVAEToGraph(graph, state); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts index 2e4383c3e7..fe6d1292e4 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts @@ -10,6 +10,7 @@ import { import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { modelIdToMainModelField } from '../modelIdToMainModelField'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addVAEToGraph } from './addVAEToGraph'; import { IMAGE_COLLECTION, @@ -304,6 +305,9 @@ export const buildLinearImageToImageGraph = ( }, }); } + + addLoRAsToGraph(graph, state, LATENTS_TO_LATENTS); + // Add VAE addVAEToGraph(graph, state); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts index e0e71a00a2..04dccf4983 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts @@ -3,6 +3,7 @@ import { NonNullableGraph } from 'features/nodes/types/types'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { modelIdToMainModelField } from '../modelIdToMainModelField'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addVAEToGraph } from './addVAEToGraph'; import { LATENTS_TO_IMAGE, @@ -150,6 +151,8 @@ export const buildLinearTextToImageGraph = ( ], }; + addLoRAsToGraph(graph, state, TEXT_TO_LATENTS); + // Add Custom VAE Support addVAEToGraph(graph, state); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts index 3265a0f889..12a567b009 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts @@ -4,6 +4,7 @@ import { cloneDeep, omit, reduce } from 'lodash-es'; import { Graph } from 'services/api/types'; import { AnyInvocation } from 'services/events/types'; import { v4 as uuidv4 } from 'uuid'; +import { modelIdToLoRAModelField } from '../modelIdToLoRAName'; import { modelIdToMainModelField } from '../modelIdToMainModelField'; import { modelIdToVAEModelField } from '../modelIdToVAEModelField'; @@ -38,6 +39,12 @@ export const parseFieldValue = (field: InputFieldValue) => { } } + if (field.type === 'lora_model') { + if (field.value) { + return modelIdToLoRAModelField(field.value); + } + } + return field.value; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts index 58a7d0335b..7aace48def 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts @@ -9,6 +9,7 @@ export const RANGE_OF_SIZE = 'range_of_size'; export const ITERATE = 'iterate'; export const MAIN_MODEL_LOADER = 'main_model_loader'; export const VAE_LOADER = 'vae_loader'; +export const LORA_LOADER = 'lora_loader'; export const IMAGE_TO_LATENTS = 'image_to_latents'; export const LATENTS_TO_LATENTS = 'latents_to_latents'; export const RESIZE = 'resize_image'; diff --git a/invokeai/frontend/web/src/features/nodes/util/modelIdToLoRAName.ts b/invokeai/frontend/web/src/features/nodes/util/modelIdToLoRAName.ts new file mode 100644 index 0000000000..052b58484b --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/modelIdToLoRAName.ts @@ -0,0 +1,12 @@ +import { BaseModelType, LoRAModelField } from 'services/api/types'; + +export const modelIdToLoRAModelField = (loraId: string): LoRAModelField => { + const [base_model, model_type, model_name] = loraId.split('/'); + + const field: LoRAModelField = { + base_model: base_model as BaseModelType, + model_name, + }; + + return field; +}; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxCollapse.tsx index fea0d8330a..b9cc8511aa 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxCollapse.tsx @@ -1,20 +1,15 @@ -import { Flex, useDisclosure } from '@chakra-ui/react'; -import { useTranslation } from 'react-i18next'; +import { Flex } from '@chakra-ui/react'; import IAICollapse from 'common/components/IAICollapse'; import { memo } from 'react'; -import ParamBoundingBoxWidth from './ParamBoundingBoxWidth'; +import { useTranslation } from 'react-i18next'; import ParamBoundingBoxHeight from './ParamBoundingBoxHeight'; +import ParamBoundingBoxWidth from './ParamBoundingBoxWidth'; const ParamBoundingBoxCollapse = () => { const { t } = useTranslation(); - const { isOpen, onToggle } = useDisclosure(); return ( - + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse.tsx index ed01da9876..a531eba57f 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse.tsx @@ -1,4 +1,4 @@ -import { Flex, useDisclosure } from '@chakra-ui/react'; +import { Flex } from '@chakra-ui/react'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -6,19 +6,14 @@ import IAICollapse from 'common/components/IAICollapse'; import ParamInfillMethod from './ParamInfillMethod'; import ParamInfillTilesize from './ParamInfillTilesize'; import ParamScaleBeforeProcessing from './ParamScaleBeforeProcessing'; -import ParamScaledWidth from './ParamScaledWidth'; import ParamScaledHeight from './ParamScaledHeight'; +import ParamScaledWidth from './ParamScaledWidth'; const ParamInfillCollapse = () => { const { t } = useTranslation(); - const { isOpen, onToggle } = useDisclosure(); return ( - + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse.tsx index 992e8b6d02..88d839fa15 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse.tsx @@ -1,22 +1,16 @@ +import IAICollapse from 'common/components/IAICollapse'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; import ParamSeamBlur from './ParamSeamBlur'; import ParamSeamSize from './ParamSeamSize'; import ParamSeamSteps from './ParamSeamSteps'; import ParamSeamStrength from './ParamSeamStrength'; -import { useDisclosure } from '@chakra-ui/react'; -import { useTranslation } from 'react-i18next'; -import IAICollapse from 'common/components/IAICollapse'; -import { memo } from 'react'; const ParamSeamCorrectionCollapse = () => { const { t } = useTranslation(); - const { isOpen, onToggle } = useDisclosure(); return ( - + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx index 06c6108dcb..59bf7542eb 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx @@ -1,41 +1,45 @@ import { Divider, Flex } from '@chakra-ui/react'; -import { useTranslation } from 'react-i18next'; -import IAICollapse from 'common/components/IAICollapse'; -import { Fragment, memo, useCallback } from 'react'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { createSelector } from '@reduxjs/toolkit'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIButton from 'common/components/IAIButton'; +import IAICollapse from 'common/components/IAICollapse'; +import ControlNet from 'features/controlNet/components/ControlNet'; +import ParamControlNetFeatureToggle from 'features/controlNet/components/parameters/ParamControlNetFeatureToggle'; import { controlNetAdded, controlNetSelector, - isControlNetEnabledToggled, } from 'features/controlNet/store/controlNetSlice'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { map } from 'lodash-es'; -import { v4 as uuidv4 } from 'uuid'; +import { getValidControlNets } from 'features/controlNet/util/getValidControlNets'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; -import IAIButton from 'common/components/IAIButton'; -import ControlNet from 'features/controlNet/components/ControlNet'; +import { map } from 'lodash-es'; +import { Fragment, memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { v4 as uuidv4 } from 'uuid'; const selector = createSelector( controlNetSelector, (controlNet) => { const { controlNets, isEnabled } = controlNet; - return { controlNetsArray: map(controlNets), isEnabled }; + const validControlNets = getValidControlNets(controlNets); + + const activeLabel = + isEnabled && validControlNets.length > 0 + ? `${validControlNets.length} Active` + : undefined; + + return { controlNetsArray: map(controlNets), activeLabel }; }, defaultSelectorOptions ); const ParamControlNetCollapse = () => { const { t } = useTranslation(); - const { controlNetsArray, isEnabled } = useAppSelector(selector); + const { controlNetsArray, activeLabel } = useAppSelector(selector); const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled; const dispatch = useAppDispatch(); - const handleClickControlNetToggle = useCallback(() => { - dispatch(isControlNetEnabledToggled()); - }, [dispatch]); - const handleClickedAddControlNet = useCallback(() => { dispatch(controlNetAdded({ controlNetId: uuidv4() })); }, [dispatch]); @@ -45,13 +49,9 @@ const ParamControlNetCollapse = () => { } return ( - + + {controlNetsArray.map((c, i) => ( {i > 0 && } diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamCFGScale.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamCFGScale.tsx index 111e3d3ae8..d32ff960d5 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamCFGScale.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamCFGScale.tsx @@ -1,5 +1,6 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAINumberInput from 'common/components/IAINumberInput'; import IAISlider from 'common/components/IAISlider'; import { generationSelector } from 'features/parameters/store/generationSelectors'; @@ -27,7 +28,8 @@ const selector = createSelector( shouldUseSliders, shift, }; - } + }, + defaultSelectorOptions ); const ParamCFGScale = () => { diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamHeight.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamHeight.tsx index 9501c8b475..6939ede424 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamHeight.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamHeight.tsx @@ -1,5 +1,6 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAISlider, { IAIFullSliderProps } from 'common/components/IAISlider'; import { generationSelector } from 'features/parameters/store/generationSelectors'; import { setHeight } from 'features/parameters/store/generationSlice'; @@ -25,7 +26,8 @@ const selector = createSelector( inputMax, step, }; - } + }, + defaultSelectorOptions ); type ParamHeightProps = Omit< diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamIterations.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamIterations.tsx index a8cdabc8c9..1e203a1e45 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamIterations.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamIterations.tsx @@ -1,37 +1,38 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAINumberInput from 'common/components/IAINumberInput'; import IAISlider from 'common/components/IAISlider'; -import { generationSelector } from 'features/parameters/store/generationSelectors'; import { setIterations } from 'features/parameters/store/generationSlice'; -import { configSelector } from 'features/system/store/configSelectors'; -import { hotkeysSelector } from 'features/ui/store/hotkeysSlice'; -import { uiSelector } from 'features/ui/store/uiSelectors'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -const selector = createSelector([stateSelector], (state) => { - const { initial, min, sliderMax, inputMax, fineStep, coarseStep } = - state.config.sd.iterations; - const { iterations } = state.generation; - const { shouldUseSliders } = state.ui; - const isDisabled = - state.dynamicPrompts.isEnabled && state.dynamicPrompts.combinatorial; +const selector = createSelector( + [stateSelector], + (state) => { + const { initial, min, sliderMax, inputMax, fineStep, coarseStep } = + state.config.sd.iterations; + const { iterations } = state.generation; + const { shouldUseSliders } = state.ui; + const isDisabled = + state.dynamicPrompts.isEnabled && state.dynamicPrompts.combinatorial; - const step = state.hotkeys.shift ? fineStep : coarseStep; + const step = state.hotkeys.shift ? fineStep : coarseStep; - return { - iterations, - initial, - min, - sliderMax, - inputMax, - step, - shouldUseSliders, - isDisabled, - }; -}); + return { + iterations, + initial, + min, + sliderMax, + inputMax, + step, + shouldUseSliders, + isDisabled, + }; + }, + defaultSelectorOptions +); const ParamIterations = () => { const { diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSteps.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSteps.tsx index f43cdd425b..d939113c7c 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSteps.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSteps.tsx @@ -1,5 +1,6 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAINumberInput from 'common/components/IAINumberInput'; import IAISlider from 'common/components/IAISlider'; @@ -33,7 +34,8 @@ const selector = createSelector( step, shouldUseSliders, }; - } + }, + defaultSelectorOptions ); const ParamSteps = () => { diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamWidth.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamWidth.tsx index b7d63038d1..b4121184b5 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamWidth.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamWidth.tsx @@ -1,7 +1,7 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import IAISlider from 'common/components/IAISlider'; -import { IAIFullSliderProps } from 'common/components/IAISlider'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAISlider, { IAIFullSliderProps } from 'common/components/IAISlider'; import { generationSelector } from 'features/parameters/store/generationSelectors'; import { setWidth } from 'features/parameters/store/generationSlice'; import { configSelector } from 'features/system/store/configSelectors'; @@ -26,7 +26,8 @@ const selector = createSelector( inputMax, step, }; - } + }, + defaultSelectorOptions ); type ParamWidthProps = Omit; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresCollapse.tsx index b4b077ad6c..fa8606d610 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresCollapse.tsx @@ -1,37 +1,39 @@ import { Flex } from '@chakra-ui/react'; -import { useTranslation } from 'react-i18next'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { RootState } from 'app/store/store'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAICollapse from 'common/components/IAICollapse'; -import { memo } from 'react'; -import { ParamHiresStrength } from './ParamHiresStrength'; -import { setHiresFix } from 'features/parameters/store/postprocessingSlice'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { ParamHiresStrength } from './ParamHiresStrength'; +import { ParamHiresToggle } from './ParamHiresToggle'; + +const selector = createSelector( + stateSelector, + (state) => { + const activeLabel = state.postprocessing.hiresFix ? 'Enabled' : undefined; + + return { activeLabel }; + }, + defaultSelectorOptions +); const ParamHiresCollapse = () => { const { t } = useTranslation(); - const hiresFix = useAppSelector( - (state: RootState) => state.postprocessing.hiresFix - ); + const { activeLabel } = useAppSelector(selector); const isHiresEnabled = useFeatureStatus('hires').isFeatureEnabled; - const dispatch = useAppDispatch(); - - const handleToggle = () => dispatch(setHiresFix(!hiresFix)); - if (!isHiresEnabled) { return null; } return ( - + + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresToggle.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresToggle.tsx index 0fc600e9e8..f8e6f22aa4 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresToggle.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresToggle.tsx @@ -23,7 +23,6 @@ export const ParamHiresToggle = () => { return ( diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseCollapse.tsx index adb76d8da0..4dea1dad4f 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseCollapse.tsx @@ -1,27 +1,33 @@ -import { useTranslation } from 'react-i18next'; import { Flex } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAICollapse from 'common/components/IAICollapse'; -import ParamPerlinNoise from './ParamPerlinNoise'; -import ParamNoiseThreshold from './ParamNoiseThreshold'; -import { RootState } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { setShouldUseNoiseSettings } from 'features/parameters/store/generationSlice'; -import { memo } from 'react'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; +import ParamNoiseThreshold from './ParamNoiseThreshold'; +import { ParamNoiseToggle } from './ParamNoiseToggle'; +import ParamPerlinNoise from './ParamPerlinNoise'; + +const selector = createSelector( + stateSelector, + (state) => { + const { shouldUseNoiseSettings } = state.generation; + return { + activeLabel: shouldUseNoiseSettings ? 'Enabled' : undefined, + }; + }, + defaultSelectorOptions +); const ParamNoiseCollapse = () => { const { t } = useTranslation(); const isNoiseEnabled = useFeatureStatus('noise').isFeatureEnabled; - const shouldUseNoiseSettings = useAppSelector( - (state: RootState) => state.generation.shouldUseNoiseSettings - ); - - const dispatch = useAppDispatch(); - - const handleToggle = () => - dispatch(setShouldUseNoiseSettings(!shouldUseNoiseSettings)); + const { activeLabel } = useAppSelector(selector); if (!isNoiseEnabled) { return null; @@ -30,11 +36,10 @@ const ParamNoiseCollapse = () => { return ( + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseThreshold.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseThreshold.tsx index e339734992..3abb7532b4 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseThreshold.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseThreshold.tsx @@ -1,18 +1,31 @@ -import { RootState } from 'app/store/store'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAISlider from 'common/components/IAISlider'; import { setThreshold } from 'features/parameters/store/generationSlice'; import { useTranslation } from 'react-i18next'; +const selector = createSelector( + stateSelector, + (state) => { + const { shouldUseNoiseSettings, threshold } = state.generation; + return { + isDisabled: !shouldUseNoiseSettings, + threshold, + }; + }, + defaultSelectorOptions +); + export default function ParamNoiseThreshold() { const dispatch = useAppDispatch(); - const threshold = useAppSelector( - (state: RootState) => state.generation.threshold - ); + const { threshold, isDisabled } = useAppSelector(selector); const { t } = useTranslation(); return ( { + const dispatch = useAppDispatch(); + + const shouldUseNoiseSettings = useAppSelector( + (state: RootState) => state.generation.shouldUseNoiseSettings + ); + + const { t } = useTranslation(); + + const handleChange = (e: ChangeEvent) => + dispatch(setShouldUseNoiseSettings(e.target.checked)); + + return ( + + ); +}; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamPerlinNoise.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamPerlinNoise.tsx index ad710eae54..afd676223c 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamPerlinNoise.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamPerlinNoise.tsx @@ -1,16 +1,31 @@ -import { RootState } from 'app/store/store'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAISlider from 'common/components/IAISlider'; import { setPerlin } from 'features/parameters/store/generationSlice'; import { useTranslation } from 'react-i18next'; +const selector = createSelector( + stateSelector, + (state) => { + const { shouldUseNoiseSettings, perlin } = state.generation; + return { + isDisabled: !shouldUseNoiseSettings, + perlin, + }; + }, + defaultSelectorOptions +); + export default function ParamPerlinNoise() { const dispatch = useAppDispatch(); - const perlin = useAppSelector((state: RootState) => state.generation.perlin); + const { perlin, isDisabled } = useAppSelector(selector); const { t } = useTranslation(); return ( { + if (seamlessXAxis && seamlessYAxis) { + return 'X & Y'; + } + + if (seamlessXAxis) { + return 'X'; + } + + if (seamlessYAxis) { + return 'Y'; + } +}; const selector = createSelector( generationSelector, (generation) => { - const { shouldUseSeamless, seamlessXAxis, seamlessYAxis } = generation; + const { seamlessXAxis, seamlessYAxis } = generation; - return { shouldUseSeamless, seamlessXAxis, seamlessYAxis }; + const activeLabel = getActiveLabel(seamlessXAxis, seamlessYAxis); + return { activeLabel }; }, defaultSelectorOptions ); const ParamSeamlessCollapse = () => { const { t } = useTranslation(); - const { shouldUseSeamless } = useAppSelector(selector); + const { activeLabel } = useAppSelector(selector); const isSeamlessEnabled = useFeatureStatus('seamless').isFeatureEnabled; - const dispatch = useAppDispatch(); - - const handleToggle = () => dispatch(setSeamless(!shouldUseSeamless)); - if (!isSeamlessEnabled) { return null; } @@ -38,9 +48,7 @@ const ParamSeamlessCollapse = () => { return ( diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse.tsx index 59bdb39be1..f2ddd19768 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse.tsx @@ -1,39 +1,39 @@ -import { memo } from 'react'; import { Flex } from '@chakra-ui/react'; +import { memo } from 'react'; import ParamSymmetryHorizontal from './ParamSymmetryHorizontal'; import ParamSymmetryVertical from './ParamSymmetryVertical'; -import { useTranslation } from 'react-i18next'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAICollapse from 'common/components/IAICollapse'; -import { RootState } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { setShouldUseSymmetry } from 'features/parameters/store/generationSlice'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; +import { useTranslation } from 'react-i18next'; +import ParamSymmetryToggle from './ParamSymmetryToggle'; + +const selector = createSelector( + stateSelector, + (state) => ({ + activeLabel: state.generation.shouldUseSymmetry ? 'Enabled' : undefined, + }), + defaultSelectorOptions +); const ParamSymmetryCollapse = () => { const { t } = useTranslation(); - const shouldUseSymmetry = useAppSelector( - (state: RootState) => state.generation.shouldUseSymmetry - ); + const { activeLabel } = useAppSelector(selector); const isSymmetryEnabled = useFeatureStatus('symmetry').isFeatureEnabled; - const dispatch = useAppDispatch(); - - const handleToggle = () => dispatch(setShouldUseSymmetry(!shouldUseSymmetry)); - if (!isSymmetryEnabled) { return null; } return ( - + + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryToggle.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryToggle.tsx index 7cc17c045e..59386ff526 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryToggle.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryToggle.tsx @@ -12,6 +12,7 @@ export default function ParamSymmetryToggle() { return ( dispatch(setShouldUseSymmetry(e.target.checked))} /> diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationCollapse.tsx index 1564bd64e5..3cdfc3a06b 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationCollapse.tsx @@ -1,39 +1,42 @@ -import ParamVariationWeights from './ParamVariationWeights'; -import ParamVariationAmount from './ParamVariationAmount'; -import { useTranslation } from 'react-i18next'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { RootState } from 'app/store/store'; -import { setShouldGenerateVariations } from 'features/parameters/store/generationSlice'; import { Flex } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAICollapse from 'common/components/IAICollapse'; -import { memo } from 'react'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; +import ParamVariationAmount from './ParamVariationAmount'; +import { ParamVariationToggle } from './ParamVariationToggle'; +import ParamVariationWeights from './ParamVariationWeights'; + +const selector = createSelector( + stateSelector, + (state) => { + const activeLabel = state.generation.shouldGenerateVariations + ? 'Enabled' + : undefined; + + return { activeLabel }; + }, + defaultSelectorOptions +); const ParamVariationCollapse = () => { const { t } = useTranslation(); - const shouldGenerateVariations = useAppSelector( - (state: RootState) => state.generation.shouldGenerateVariations - ); + const { activeLabel } = useAppSelector(selector); const isVariationEnabled = useFeatureStatus('variation').isFeatureEnabled; - const dispatch = useAppDispatch(); - - const handleToggle = () => - dispatch(setShouldGenerateVariations(!shouldGenerateVariations)); - if (!isVariationEnabled) { return null; } return ( - + + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationToggle.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationToggle.tsx new file mode 100644 index 0000000000..1c05468de0 --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationToggle.tsx @@ -0,0 +1,27 @@ +import type { RootState } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import IAISwitch from 'common/components/IAISwitch'; +import { setShouldGenerateVariations } from 'features/parameters/store/generationSlice'; +import { ChangeEvent } from 'react'; +import { useTranslation } from 'react-i18next'; + +export const ParamVariationToggle = () => { + const dispatch = useAppDispatch(); + + const shouldGenerateVariations = useAppSelector( + (state: RootState) => state.generation.shouldGenerateVariations + ); + + const { t } = useTranslation(); + + const handleChange = (e: ChangeEvent) => + dispatch(setShouldGenerateVariations(e.target.checked)); + + return ( + + ); +}; diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index 209cf4b639..960a41bb45 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -49,7 +49,6 @@ export interface GenerationState { verticalSymmetrySteps: number; model: ModelParam; vae: VAEParam; - shouldUseSeamless: boolean; seamlessXAxis: boolean; seamlessYAxis: boolean; } @@ -84,9 +83,8 @@ export const initialGenerationState: GenerationState = { verticalSymmetrySteps: 0, model: '', vae: '', - shouldUseSeamless: false, - seamlessXAxis: true, - seamlessYAxis: true, + seamlessXAxis: false, + seamlessYAxis: false, }; const initialState: GenerationState = initialGenerationState; @@ -144,9 +142,6 @@ export const generationSlice = createSlice({ setImg2imgStrength: (state, action: PayloadAction) => { state.img2imgStrength = action.payload; }, - setSeamless: (state, action: PayloadAction) => { - state.shouldUseSeamless = action.payload; - }, setSeamlessXAxis: (state, action: PayloadAction) => { state.seamlessXAxis = action.payload; }, @@ -268,7 +263,6 @@ export const { modelSelected, vaeSelected, setShouldUseNoiseSettings, - setSeamless, setSeamlessXAxis, setSeamlessYAxis, } = generationSlice.actions; diff --git a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx index 4232858621..4eeee3e4c6 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx +++ b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx @@ -8,7 +8,7 @@ import { modelSelected } from 'features/parameters/store/generationSlice'; import { SelectItem } from '@mantine/core'; import { RootState } from 'app/store/store'; import { forEach, isString } from 'lodash-es'; -import { useListModelsQuery } from 'services/api/endpoints/models'; +import { useGetMainModelsQuery } from 'services/api/endpoints/models'; export const MODEL_TYPE_MAP = { 'sd-1': 'Stable Diffusion 1.x', @@ -23,9 +23,7 @@ const ModelSelect = () => { (state: RootState) => state.generation.model ); - const { data: mainModels, isLoading } = useListModelsQuery({ - model_type: 'main', - }); + const { data: mainModels, isLoading } = useGetMainModelsQuery(); const data = useMemo(() => { if (!mainModels) { diff --git a/invokeai/frontend/web/src/features/system/components/VAESelect.tsx b/invokeai/frontend/web/src/features/system/components/VAESelect.tsx index 19b508d30f..33901b5bef 100644 --- a/invokeai/frontend/web/src/features/system/components/VAESelect.tsx +++ b/invokeai/frontend/web/src/features/system/components/VAESelect.tsx @@ -6,7 +6,7 @@ import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { SelectItem } from '@mantine/core'; import { forEach } from 'lodash-es'; -import { useListModelsQuery } from 'services/api/endpoints/models'; +import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; import { RootState } from 'app/store/store'; import { vaeSelected } from 'features/parameters/store/generationSlice'; @@ -16,9 +16,7 @@ const VAESelect = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { data: vaeModels } = useListModelsQuery({ - model_type: 'vae', - }); + const { data: vaeModels } = useGetVaeModelsQuery(); const selectedModelId = useAppSelector( (state: RootState) => state.generation.vae diff --git a/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx b/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx index 6986ded0a7..c618997f03 100644 --- a/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx +++ b/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx @@ -66,16 +66,16 @@ const tabs: InvokeTabInfo[] = [ icon: , content: , }, - // { - // id: 'batch', - // icon: , - // content: , - // }, { id: 'modelManager', icon: , content: , }, + // { + // id: 'batch', + // icon: , + // content: , + // }, ]; const enabledTabsSelector = createSelector( diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters.tsx index 89286232c6..5f5c7ad46b 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters.tsx @@ -1,4 +1,4 @@ -import { Box, Flex, useDisclosure } from '@chakra-ui/react'; +import { Box, Flex } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; @@ -21,19 +21,25 @@ const selector = createSelector( [uiSelector, generationSelector], (ui, generation) => { const { shouldUseSliders } = ui; - const { shouldFitToWidthHeight } = generation; + const { shouldFitToWidthHeight, shouldRandomizeSeed } = generation; - return { shouldUseSliders, shouldFitToWidthHeight }; + const activeLabel = !shouldRandomizeSeed ? 'Manual Seed' : undefined; + + return { shouldUseSliders, shouldFitToWidthHeight, activeLabel }; }, defaultSelectorOptions ); const ImageToImageTabCoreParameters = () => { - const { shouldUseSliders, shouldFitToWidthHeight } = useAppSelector(selector); - const { isOpen, onToggle } = useDisclosure({ defaultIsOpen: true }); + const { shouldUseSliders, shouldFitToWidthHeight, activeLabel } = + useAppSelector(selector); return ( - + { return ( @@ -17,6 +18,7 @@ const ImageToImageTabParameters = () => { + diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx index 0cd90a9492..b71b5636b4 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx @@ -9,16 +9,14 @@ import IAISlider from 'common/components/IAISlider'; import { pickBy } from 'lodash-es'; import { useState } from 'react'; import { useTranslation } from 'react-i18next'; -import { useListModelsQuery } from 'services/api/endpoints/models'; +import { useGetMainModelsQuery } from 'services/api/endpoints/models'; export default function MergeModelsPanel() { const { t } = useTranslation(); const dispatch = useAppDispatch(); - const { data } = useListModelsQuery({ - model_type: 'main', - }); + const { data } = useGetMainModelsQuery(); const diffusersModels = pickBy( data?.entities, diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx index 228fb79c2e..b22a303571 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx @@ -2,15 +2,13 @@ import { Flex } from '@chakra-ui/react'; import { RootState } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; -import { useListModelsQuery } from 'services/api/endpoints/models'; +import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit'; import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit'; import ModelList from './ModelManagerPanel/ModelList'; export default function ModelManagerPanel() { - const { data: mainModels } = useListModelsQuery({ - model_type: 'main', - }); + const { data: mainModels } = useGetMainModelsQuery(); const openModel = useAppSelector( (state: RootState) => state.system.openModel diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx index fac89b7edc..eb05e70357 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx @@ -8,7 +8,7 @@ import { useTranslation } from 'react-i18next'; import type { ChangeEvent, ReactNode } from 'react'; import React, { useMemo, useState, useTransition } from 'react'; -import { useListModelsQuery } from 'services/api/endpoints/models'; +import { useGetMainModelsQuery } from 'services/api/endpoints/models'; function ModelFilterButton({ label, @@ -36,9 +36,7 @@ function ModelFilterButton({ } const ModelList = () => { - const { data: mainModels } = useListModelsQuery({ - model_type: 'main', - }); + const { data: mainModels } = useGetMainModelsQuery(); const [renderModelList, setRenderModelList] = React.useState(false); diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters.tsx index 75d54667e9..9211e095ba 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters.tsx @@ -1,5 +1,6 @@ -import { Box, Flex, useDisclosure } from '@chakra-ui/react'; +import { Box, Flex } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAICollapse from 'common/components/IAICollapse'; @@ -11,25 +12,30 @@ import ParamScheduler from 'features/parameters/components/Parameters/Core/Param import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps'; import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth'; import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull'; -import { uiSelector } from 'features/ui/store/uiSelectors'; import { memo } from 'react'; const selector = createSelector( - uiSelector, - (ui) => { + stateSelector, + ({ ui, generation }) => { const { shouldUseSliders } = ui; + const { shouldRandomizeSeed } = generation; - return { shouldUseSliders }; + const activeLabel = !shouldRandomizeSeed ? 'Manual Seed' : undefined; + + return { shouldUseSliders, activeLabel }; }, defaultSelectorOptions ); const TextToImageTabCoreParameters = () => { - const { shouldUseSliders } = useAppSelector(selector); - const { isOpen, onToggle } = useDisclosure({ defaultIsOpen: true }); + const { shouldUseSliders, activeLabel } = useAppSelector(selector); return ( - + { return ( @@ -18,6 +19,7 @@ const TextToImageTabParameters = () => { + diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx index 9226973101..330cd8b31e 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx @@ -1,5 +1,6 @@ -import { Box, Flex, useDisclosure } from '@chakra-ui/react'; +import { Box, Flex } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAICollapse from 'common/components/IAICollapse'; @@ -12,25 +13,30 @@ import ParamScheduler from 'features/parameters/components/Parameters/Core/Param import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps'; import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength'; import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull'; -import { uiSelector } from 'features/ui/store/uiSelectors'; import { memo } from 'react'; const selector = createSelector( - uiSelector, - (ui) => { + stateSelector, + ({ ui, generation }) => { const { shouldUseSliders } = ui; + const { shouldRandomizeSeed } = generation; - return { shouldUseSliders }; + const activeLabel = !shouldRandomizeSeed ? 'Manual Seed' : undefined; + + return { shouldUseSliders, activeLabel }; }, defaultSelectorOptions ); const UnifiedCanvasCoreParameters = () => { - const { shouldUseSliders } = useAppSelector(selector); - const { isOpen, onToggle } = useDisclosure({ defaultIsOpen: true }); + const { shouldUseSliders, activeLabel } = useAppSelector(selector); return ( - + { return ( @@ -17,6 +18,7 @@ const UnifiedCanvasParameters = () => { + diff --git a/invokeai/frontend/web/src/features/ui/store/tabMap.ts b/invokeai/frontend/web/src/features/ui/store/tabMap.ts index 7c85805fe7..0cae8eac43 100644 --- a/invokeai/frontend/web/src/features/ui/store/tabMap.ts +++ b/invokeai/frontend/web/src/features/ui/store/tabMap.ts @@ -1,13 +1,10 @@ export const tabMap = [ 'txt2img', 'img2img', - // 'generate', 'unifiedCanvas', 'nodes', - 'batch', - // 'postprocessing', - // 'training', 'modelManager', + 'batch', ] as const; export type InvokeTabName = (typeof tabMap)[number]; diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 39e4e46d3b..a9a914f0f2 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -1,37 +1,85 @@ -import { ModelsList } from 'services/api/types'; import { EntityState, createEntityAdapter } from '@reduxjs/toolkit'; -import { keyBy } from 'lodash-es'; +import { cloneDeep } from 'lodash-es'; +import { + AnyModelConfig, + ControlNetModelConfig, + LoRAModelConfig, + MainModelConfig, + TextualInversionModelConfig, + VaeModelConfig, +} from 'services/api/types'; import { ApiFullTagDescription, LIST_TAG, api } from '..'; -import { paths } from '../schema'; -type ModelConfig = ModelsList['models'][number]; +export type MainModelConfigEntity = MainModelConfig & { id: string }; -type ListModelsArg = NonNullable< - paths['/api/v1/models/']['get']['parameters']['query'] ->; +export type LoRAModelConfigEntity = LoRAModelConfig & { id: string }; -const modelsAdapter = createEntityAdapter({ - selectId: (model) => getModelId(model), +export type ControlNetModelConfigEntity = ControlNetModelConfig & { + id: string; +}; + +export type TextualInversionModelConfigEntity = TextualInversionModelConfig & { + id: string; +}; + +export type VaeModelConfigEntity = VaeModelConfig & { id: string }; + +type AnyModelConfigEntity = + | MainModelConfigEntity + | LoRAModelConfigEntity + | ControlNetModelConfigEntity + | TextualInversionModelConfigEntity + | VaeModelConfigEntity; + +const mainModelsAdapter = createEntityAdapter({ + sortComparer: (a, b) => a.name.localeCompare(b.name), +}); +const loraModelsAdapter = createEntityAdapter({ + sortComparer: (a, b) => a.name.localeCompare(b.name), +}); +const controlNetModelsAdapter = + createEntityAdapter({ + sortComparer: (a, b) => a.name.localeCompare(b.name), + }); +const textualInversionModelsAdapter = + createEntityAdapter({ + sortComparer: (a, b) => a.name.localeCompare(b.name), + }); +const vaeModelsAdapter = createEntityAdapter({ sortComparer: (a, b) => a.name.localeCompare(b.name), }); -const getModelId = ({ base_model, type, name }: ModelConfig) => +export const getModelId = ({ base_model, type, name }: AnyModelConfig) => `${base_model}/${type}/${name}`; +const createModelEntities = ( + models: AnyModelConfig[] +): T[] => { + const entityArray: T[] = []; + models.forEach((model) => { + const entity = { + ...cloneDeep(model), + id: getModelId(model), + } as T; + entityArray.push(entity); + }); + return entityArray; +}; + export const modelsApi = api.injectEndpoints({ endpoints: (build) => ({ - listModels: build.query, ListModelsArg>({ - query: (arg) => ({ url: 'models/', params: arg }), + getMainModels: build.query, void>({ + query: () => ({ url: 'models/', params: { model_type: 'main' } }), providesTags: (result, error, arg) => { - // any list of boards - const tags: ApiFullTagDescription[] = [{ id: 'Model', type: LIST_TAG }]; + const tags: ApiFullTagDescription[] = [ + { id: 'MainModel', type: LIST_TAG }, + ]; if (result) { - // and individual tags for each board tags.push( ...result.ids.map((id) => ({ - type: 'Model' as const, + type: 'MainModel' as const, id, })) ); @@ -39,14 +87,161 @@ export const modelsApi = api.injectEndpoints({ return tags; }, - transformResponse: (response: ModelsList, meta, arg) => { - return modelsAdapter.setAll( - modelsAdapter.getInitialState(), - keyBy(response.models, getModelId) + transformResponse: ( + response: { models: MainModelConfig[] }, + meta, + arg + ) => { + const entities = createModelEntities( + response.models + ); + return mainModelsAdapter.setAll( + mainModelsAdapter.getInitialState(), + entities + ); + }, + }), + getLoRAModels: build.query, void>({ + query: () => ({ url: 'models/', params: { model_type: 'lora' } }), + providesTags: (result, error, arg) => { + const tags: ApiFullTagDescription[] = [ + { id: 'LoRAModel', type: LIST_TAG }, + ]; + + if (result) { + tags.push( + ...result.ids.map((id) => ({ + type: 'LoRAModel' as const, + id, + })) + ); + } + + return tags; + }, + transformResponse: ( + response: { models: LoRAModelConfig[] }, + meta, + arg + ) => { + const entities = createModelEntities( + response.models + ); + return loraModelsAdapter.setAll( + loraModelsAdapter.getInitialState(), + entities + ); + }, + }), + getControlNetModels: build.query< + EntityState, + void + >({ + query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }), + providesTags: (result, error, arg) => { + const tags: ApiFullTagDescription[] = [ + { id: 'ControlNetModel', type: LIST_TAG }, + ]; + + if (result) { + tags.push( + ...result.ids.map((id) => ({ + type: 'ControlNetModel' as const, + id, + })) + ); + } + + return tags; + }, + transformResponse: ( + response: { models: ControlNetModelConfig[] }, + meta, + arg + ) => { + const entities = createModelEntities( + response.models + ); + return controlNetModelsAdapter.setAll( + controlNetModelsAdapter.getInitialState(), + entities + ); + }, + }), + getVaeModels: build.query, void>({ + query: () => ({ url: 'models/', params: { model_type: 'vae' } }), + providesTags: (result, error, arg) => { + const tags: ApiFullTagDescription[] = [ + { id: 'VaeModel', type: LIST_TAG }, + ]; + + if (result) { + tags.push( + ...result.ids.map((id) => ({ + type: 'VaeModel' as const, + id, + })) + ); + } + + return tags; + }, + transformResponse: ( + response: { models: VaeModelConfig[] }, + meta, + arg + ) => { + const entities = createModelEntities( + response.models + ); + return vaeModelsAdapter.setAll( + vaeModelsAdapter.getInitialState(), + entities + ); + }, + }), + getTextualInversionModels: build.query< + EntityState, + void + >({ + query: () => ({ url: 'models/', params: { model_type: 'embedding' } }), + providesTags: (result, error, arg) => { + const tags: ApiFullTagDescription[] = [ + { id: 'TextualInversionModel', type: LIST_TAG }, + ]; + + if (result) { + tags.push( + ...result.ids.map((id) => ({ + type: 'TextualInversionModel' as const, + id, + })) + ); + } + + return tags; + }, + transformResponse: ( + response: { models: TextualInversionModelConfig[] }, + meta, + arg + ) => { + const entities = createModelEntities( + response.models + ); + return textualInversionModelsAdapter.setAll( + textualInversionModelsAdapter.getInitialState(), + entities ); }, }), }), }); -export const { useListModelsQuery } = modelsApi; +export const { + useGetMainModelsQuery, + useGetControlNetModelsQuery, + useGetLoRAModelsQuery, + useGetTextualInversionModelsQuery, + useGetVaeModelsQuery, +} = modelsApi; diff --git a/invokeai/frontend/web/src/services/api/schema.d.ts b/invokeai/frontend/web/src/services/api/schema.d.ts index 7dce36d6b3..d7e50d004e 100644 --- a/invokeai/frontend/web/src/services/api/schema.d.ts +++ b/invokeai/frontend/web/src/services/api/schema.d.ts @@ -2690,6 +2690,19 @@ export type components = { model_format: components["schemas"]["LoRAModelFormat"]; error?: components["schemas"]["ModelError"]; }; + /** + * LoRAModelField + * @description LoRA model field + */ + LoRAModelField: { + /** + * Model Name + * @description Name of the LoRA model + */ + model_name: string; + /** @description Base model */ + base_model: components["schemas"]["BaseModelType"]; + }; /** * LoRAModelFormat * @description An enumeration. @@ -2766,10 +2779,10 @@ export type components = { */ type?: "lora_loader"; /** - * Lora Name + * Lora * @description Lora model name */ - lora_name: string; + lora?: components["schemas"]["LoRAModelField"]; /** * Weight * @description With what weight to apply lora @@ -3115,7 +3128,7 @@ export type components = { /** ModelsList */ ModelsList: { /** Models */ - models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[]; + models: (components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[]; }; /** * MultiplyInvocation @@ -4448,18 +4461,18 @@ export type components = { */ image?: components["schemas"]["ImageField"]; }; - /** - * StableDiffusion2ModelFormat - * @description An enumeration. - * @enum {string} - */ - StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; /** * StableDiffusion1ModelFormat * @description An enumeration. * @enum {string} */ StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; + /** + * StableDiffusion2ModelFormat + * @description An enumeration. + * @enum {string} + */ + StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; }; responses: never; parameters: never; diff --git a/invokeai/frontend/web/src/services/api/types.d.ts b/invokeai/frontend/web/src/services/api/types.d.ts index 18942a47d6..3a0bdb71a7 100644 --- a/invokeai/frontend/web/src/services/api/types.d.ts +++ b/invokeai/frontend/web/src/services/api/types.d.ts @@ -4,91 +4,156 @@ import { components } from './schema'; type schemas = components['schemas']; /** - * Extracts the schema type from the schema. + * Marks the `type` property as required. Use for nodes. */ -type S = components['schemas'][T]; - -/** - * Extracts the node type from the schema. - * Also flags the `type` property as required. - */ -type N = O.Required< - components['schemas'][T], - 'type' ->; +type TypeReq = O.Required; // Images -export type ImageDTO = S<'ImageDTO'>; -export type BoardDTO = S<'BoardDTO'>; -export type BoardChanges = S<'BoardChanges'>; -export type ImageChanges = S<'ImageRecordChanges'>; -export type ImageCategory = S<'ImageCategory'>; -export type ResourceOrigin = S<'ResourceOrigin'>; -export type ImageField = S<'ImageField'>; +export type ImageDTO = components['schemas']['ImageDTO']; +export type BoardDTO = components['schemas']['BoardDTO']; +export type BoardChanges = components['schemas']['BoardChanges']; +export type ImageChanges = components['schemas']['ImageRecordChanges']; +export type ImageCategory = components['schemas']['ImageCategory']; +export type ResourceOrigin = components['schemas']['ResourceOrigin']; +export type ImageField = components['schemas']['ImageField']; export type OffsetPaginatedResults_BoardDTO_ = - S<'OffsetPaginatedResults_BoardDTO_'>; + components['schemas']['OffsetPaginatedResults_BoardDTO_']; export type OffsetPaginatedResults_ImageDTO_ = - S<'OffsetPaginatedResults_ImageDTO_'>; + components['schemas']['OffsetPaginatedResults_ImageDTO_']; // Models -export type ModelType = S<'ModelType'>; -export type BaseModelType = S<'BaseModelType'>; -export type MainModelField = S<'MainModelField'>; -export type VAEModelField = S<'VAEModelField'>; -export type ModelsList = S<'ModelsList'>; +export type ModelType = components['schemas']['ModelType']; +export type BaseModelType = components['schemas']['BaseModelType']; +export type MainModelField = components['schemas']['MainModelField']; +export type VAEModelField = components['schemas']['VAEModelField']; +export type LoRAModelField = components['schemas']['LoRAModelField']; +export type ModelsList = components['schemas']['ModelsList']; + +// Model Configs +export type LoRAModelConfig = components['schemas']['LoRAModelConfig']; +export type VaeModelConfig = components['schemas']['VaeModelConfig']; +export type ControlNetModelConfig = + components['schemas']['ControlNetModelConfig']; +export type TextualInversionModelConfig = + components['schemas']['TextualInversionModelConfig']; +export type MainModelConfig = + | components['schemas']['StableDiffusion1ModelCheckpointConfig'] + | components['schemas']['StableDiffusion1ModelDiffusersConfig'] + | components['schemas']['StableDiffusion2ModelCheckpointConfig'] + | components['schemas']['StableDiffusion2ModelDiffusersConfig']; +export type AnyModelConfig = + | LoRAModelConfig + | VaeModelConfig + | ControlNetModelConfig + | TextualInversionModelConfig + | MainModelConfig; // Graphs -export type Graph = S<'Graph'>; -export type Edge = S<'Edge'>; -export type GraphExecutionState = S<'GraphExecutionState'>; +export type Graph = components['schemas']['Graph']; +export type Edge = components['schemas']['Edge']; +export type GraphExecutionState = components['schemas']['GraphExecutionState']; // General nodes -export type CollectInvocation = N<'CollectInvocation'>; -export type IterateInvocation = N<'IterateInvocation'>; -export type RangeInvocation = N<'RangeInvocation'>; -export type RandomRangeInvocation = N<'RandomRangeInvocation'>; -export type RangeOfSizeInvocation = N<'RangeOfSizeInvocation'>; -export type InpaintInvocation = N<'InpaintInvocation'>; -export type ImageResizeInvocation = N<'ImageResizeInvocation'>; -export type RandomIntInvocation = N<'RandomIntInvocation'>; -export type CompelInvocation = N<'CompelInvocation'>; -export type DynamicPromptInvocation = N<'DynamicPromptInvocation'>; -export type NoiseInvocation = N<'NoiseInvocation'>; -export type TextToLatentsInvocation = N<'TextToLatentsInvocation'>; -export type LatentsToLatentsInvocation = N<'LatentsToLatentsInvocation'>; -export type ImageToLatentsInvocation = N<'ImageToLatentsInvocation'>; -export type LatentsToImageInvocation = N<'LatentsToImageInvocation'>; -export type ImageCollectionInvocation = N<'ImageCollectionInvocation'>; -export type MainModelLoaderInvocation = N<'MainModelLoaderInvocation'>; +export type CollectInvocation = TypeReq< + components['schemas']['CollectInvocation'] +>; +export type IterateInvocation = TypeReq< + components['schemas']['IterateInvocation'] +>; +export type RangeInvocation = TypeReq; +export type RandomRangeInvocation = TypeReq< + components['schemas']['RandomRangeInvocation'] +>; +export type RangeOfSizeInvocation = TypeReq< + components['schemas']['RangeOfSizeInvocation'] +>; +export type InpaintInvocation = TypeReq< + components['schemas']['InpaintInvocation'] +>; +export type ImageResizeInvocation = TypeReq< + components['schemas']['ImageResizeInvocation'] +>; +export type RandomIntInvocation = TypeReq< + components['schemas']['RandomIntInvocation'] +>; +export type CompelInvocation = TypeReq< + components['schemas']['CompelInvocation'] +>; +export type DynamicPromptInvocation = TypeReq< + components['schemas']['DynamicPromptInvocation'] +>; +export type NoiseInvocation = TypeReq; +export type TextToLatentsInvocation = TypeReq< + components['schemas']['TextToLatentsInvocation'] +>; +export type LatentsToLatentsInvocation = TypeReq< + components['schemas']['LatentsToLatentsInvocation'] +>; +export type ImageToLatentsInvocation = TypeReq< + components['schemas']['ImageToLatentsInvocation'] +>; +export type LatentsToImageInvocation = TypeReq< + components['schemas']['LatentsToImageInvocation'] +>; +export type ImageCollectionInvocation = TypeReq< + components['schemas']['ImageCollectionInvocation'] +>; +export type MainModelLoaderInvocation = TypeReq< + components['schemas']['MainModelLoaderInvocation'] +>; +export type LoraLoaderInvocation = TypeReq< + components['schemas']['LoraLoaderInvocation'] +>; // ControlNet Nodes -export type ControlNetInvocation = N<'ControlNetInvocation'>; -export type CannyImageProcessorInvocation = N<'CannyImageProcessorInvocation'>; -export type ContentShuffleImageProcessorInvocation = - N<'ContentShuffleImageProcessorInvocation'>; -export type HedImageProcessorInvocation = N<'HedImageProcessorInvocation'>; -export type LineartAnimeImageProcessorInvocation = - N<'LineartAnimeImageProcessorInvocation'>; -export type LineartImageProcessorInvocation = - N<'LineartImageProcessorInvocation'>; -export type MediapipeFaceProcessorInvocation = - N<'MediapipeFaceProcessorInvocation'>; -export type MidasDepthImageProcessorInvocation = - N<'MidasDepthImageProcessorInvocation'>; -export type MlsdImageProcessorInvocation = N<'MlsdImageProcessorInvocation'>; -export type NormalbaeImageProcessorInvocation = - N<'NormalbaeImageProcessorInvocation'>; -export type OpenposeImageProcessorInvocation = - N<'OpenposeImageProcessorInvocation'>; -export type PidiImageProcessorInvocation = N<'PidiImageProcessorInvocation'>; -export type ZoeDepthImageProcessorInvocation = - N<'ZoeDepthImageProcessorInvocation'>; +export type ControlNetInvocation = TypeReq< + components['schemas']['ControlNetInvocation'] +>; +export type CannyImageProcessorInvocation = TypeReq< + components['schemas']['CannyImageProcessorInvocation'] +>; +export type ContentShuffleImageProcessorInvocation = TypeReq< + components['schemas']['ContentShuffleImageProcessorInvocation'] +>; +export type HedImageProcessorInvocation = TypeReq< + components['schemas']['HedImageProcessorInvocation'] +>; +export type LineartAnimeImageProcessorInvocation = TypeReq< + components['schemas']['LineartAnimeImageProcessorInvocation'] +>; +export type LineartImageProcessorInvocation = TypeReq< + components['schemas']['LineartImageProcessorInvocation'] +>; +export type MediapipeFaceProcessorInvocation = TypeReq< + components['schemas']['MediapipeFaceProcessorInvocation'] +>; +export type MidasDepthImageProcessorInvocation = TypeReq< + components['schemas']['MidasDepthImageProcessorInvocation'] +>; +export type MlsdImageProcessorInvocation = TypeReq< + components['schemas']['MlsdImageProcessorInvocation'] +>; +export type NormalbaeImageProcessorInvocation = TypeReq< + components['schemas']['NormalbaeImageProcessorInvocation'] +>; +export type OpenposeImageProcessorInvocation = TypeReq< + components['schemas']['OpenposeImageProcessorInvocation'] +>; +export type PidiImageProcessorInvocation = TypeReq< + components['schemas']['PidiImageProcessorInvocation'] +>; +export type ZoeDepthImageProcessorInvocation = TypeReq< + components['schemas']['ZoeDepthImageProcessorInvocation'] +>; // Node Outputs -export type ImageOutput = S<'ImageOutput'>; -export type MaskOutput = S<'MaskOutput'>; -export type PromptOutput = S<'PromptOutput'>; -export type IterateInvocationOutput = S<'IterateInvocationOutput'>; -export type CollectInvocationOutput = S<'CollectInvocationOutput'>; -export type LatentsOutput = S<'LatentsOutput'>; -export type GraphInvocationOutput = S<'GraphInvocationOutput'>; +export type ImageOutput = components['schemas']['ImageOutput']; +export type MaskOutput = components['schemas']['MaskOutput']; +export type PromptOutput = components['schemas']['PromptOutput']; +export type IterateInvocationOutput = + components['schemas']['IterateInvocationOutput']; +export type CollectInvocationOutput = + components['schemas']['CollectInvocationOutput']; +export type LatentsOutput = components['schemas']['LatentsOutput']; +export type GraphInvocationOutput = + components['schemas']['GraphInvocationOutput']; diff --git a/invokeai/frontend/web/src/theme/components/textarea.ts b/invokeai/frontend/web/src/theme/components/textarea.ts index 85e6e37d3f..b737cf5e57 100644 --- a/invokeai/frontend/web/src/theme/components/textarea.ts +++ b/invokeai/frontend/web/src/theme/components/textarea.ts @@ -1,7 +1,28 @@ import { defineStyle, defineStyleConfig } from '@chakra-ui/react'; import { getInputOutlineStyles } from '../util/getInputOutlineStyles'; -const invokeAI = defineStyle((props) => getInputOutlineStyles(props)); +const invokeAI = defineStyle((props) => ({ + ...getInputOutlineStyles(props), + '::-webkit-scrollbar': { + display: 'initial', + }, + '::-webkit-resizer': { + backgroundImage: `linear-gradient(135deg, + var(--invokeai-colors-base-50) 0%, + var(--invokeai-colors-base-50) 70%, + var(--invokeai-colors-base-200) 70%, + var(--invokeai-colors-base-200) 100%)`, + }, + _dark: { + '::-webkit-resizer': { + backgroundImage: `linear-gradient(135deg, + var(--invokeai-colors-base-900) 0%, + var(--invokeai-colors-base-900) 70%, + var(--invokeai-colors-base-800) 70%, + var(--invokeai-colors-base-800) 100%)`, + }, + }, +})); export const textareaTheme = defineStyleConfig({ variants: {