diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 0b03c8e729..dcbdbec82d 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -2,17 +2,17 @@ from typing import Literal, Optional, Union -from fastapi import Query +from fastapi import Query, Body from fastapi.routing import APIRouter, HTTPException from pydantic import BaseModel, Field, parse_obj_as from ..dependencies import ApiDependencies from invokeai.backend import BaseModelType, ModelType +from invokeai.backend.model_management import AddModelResult from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)] models_router = APIRouter(prefix="/v1/models", tags=["models"]) - class VaeRepo(BaseModel): repo_id: str = Field(description="The repo ID to use for this VAE") path: Optional[str] = Field(description="The path to the VAE") @@ -51,9 +51,12 @@ class CreateModelResponse(BaseModel): info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info") status: str = Field(description="The status of the API response") -class ImportModelRequest(BaseModel): - name: str = Field(description="A model path, repo_id or URL to import") - prediction_type: Optional[Literal['epsilon','v_prediction','sample']] = Field(description='Prediction type for SDv2 checkpoint files') +class ImportModelResponse(BaseModel): + name: str = Field(description="The name of the imported model") +# base_model: str = Field(description="The base model") +# model_type: str = Field(description="The model type") + info: AddModelResult = Field(description="The model info") + status: str = Field(description="The status of the API response") class ConversionRequest(BaseModel): name: str = Field(description="The name of the new model") @@ -86,7 +89,6 @@ async def list_models( models = parse_obj_as(ModelsList, { "models": models_raw }) return models - @models_router.post( "/", operation_id="update_model", @@ -109,27 +111,38 @@ async def update_model( return model_response @models_router.post( - "/", + "/import", operation_id="import_model", - responses={200: {"status": "success"}}, + responses= { + 201: {"description" : "The model imported successfully"}, + 404: {"description" : "The model could not be found"}, + }, + status_code=201, + response_model=ImportModelResponse ) async def import_model( - model_request: ImportModelRequest -) -> None: - """ Add Model """ - items_to_import = set([model_request.name]) + name: str = Query(description="A model path, repo_id or URL to import"), + prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = Query(description='Prediction type for SDv2 checkpoint files', default="v_prediction"), +) -> ImportModelResponse: + """ Add a model using its local path, repo_id, or remote URL """ + items_to_import = {name} prediction_types = { x.value: x for x in SchedulerPredictionType } logger = ApiDependencies.invoker.services.logger installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import( items_to_import = items_to_import, - prediction_type_helper = lambda x: prediction_types.get(model_request.prediction_type) + prediction_type_helper = lambda x: prediction_types.get(prediction_type) ) - if len(installed_models) > 0: - logger.info(f'Successfully imported {model_request.name}') + if info := installed_models.get(name): + logger.info(f'Successfully imported {name}, got {info}') + return ImportModelResponse( + name = name, + info = info, + status = "success", + ) else: - logger.error(f'Model {model_request.name} not imported') - raise HTTPException(status_code=500, detail=f'Model {model_request.name} not imported') + logger.error(f'Model {name} not imported') + raise HTTPException(status_code=404, detail=f'Model {name} not found') @models_router.delete( "/{model_name}", diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 1bf9353368..4c7314bd2b 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -4,9 +4,10 @@ from __future__ import annotations from abc import ABC, abstractmethod from inspect import signature -from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict, TYPE_CHECKING +from typing import (TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args, + get_type_hints) -from pydantic import BaseModel, Field +from pydantic import BaseConfig, BaseModel, Field if TYPE_CHECKING: from ..services.invocation_services import InvocationServices @@ -65,8 +66,13 @@ class BaseInvocation(ABC, BaseModel): @classmethod def get_invocations_map(cls): # Get the type strings out of the literals and into a dictionary - return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseInvocation.get_all_subclasses())) - + return dict( + map( + lambda t: (get_args(get_type_hints(t)["type"])[0], t), + BaseInvocation.get_all_subclasses(), + ) + ) + @classmethod def get_output_type(cls): return signature(cls.invoke).return_annotation @@ -75,11 +81,11 @@ class BaseInvocation(ABC, BaseModel): def invoke(self, context: InvocationContext) -> BaseInvocationOutput: """Invoke with provided context and return outputs.""" pass - - #fmt: off + + # fmt: off id: str = Field(description="The id of this node. Must be unique among all nodes.") is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.") - #fmt: on + # fmt: on # TODO: figure out a better way to provide these hints @@ -98,16 +104,19 @@ class UIConfig(TypedDict, total=False): "model", "control", "image_collection", + "vae_model", + "lora_model", ], ] tags: List[str] title: str + class CustomisedSchemaExtra(TypedDict): ui: UIConfig -class InvocationConfig(BaseModel.Config): +class InvocationConfig(BaseConfig): """Customizes pydantic's BaseModel.Config class for use by Invocations. Provide `schema_extra` a `ui` dict to add hints for generated UIs. diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 333277540e..f5f387f7e7 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -1,27 +1,24 @@ -from typing import Literal, Optional, Union +from typing import Literal, Optional, Union, List from pydantic import BaseModel, Field import re import torch - -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig -from .model import ClipField - +from compel import Compel +from compel.prompt_parser import (Blend, Conjunction, + CrossAttentionControlSubstitute, + FlattenedPrompt, Fragment) from ...backend.util.devices import torch_dtype -from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent from ...backend.model_management import ModelType from ...backend.model_management.lora import ModelPatcher - -from compel import Compel -from compel.prompt_parser import ( - Blend, - CrossAttentionControlSubstitute, - FlattenedPrompt, - Fragment, Conjunction, -) +from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent +from .baseinvocation import (BaseInvocation, BaseInvocationOutput, + InvocationConfig, InvocationContext) +from .model import ClipField class ConditioningField(BaseModel): - conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data") + conditioning_name: Optional[str] = Field( + default=None, description="The name of conditioning data") + class Config: schema_extra = {"required": ["conditioning_name"]} @@ -51,84 +48,92 @@ class CompelInvocation(BaseInvocation): "title": "Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": { - "model": "model" + "model": "model" } }, } @torch.no_grad() def invoke(self, context: InvocationContext) -> CompelOutput: - tokenizer_info = context.services.model_manager.get_model( **self.clip.tokenizer.dict(), ) text_encoder_info = context.services.model_manager.get_model( **self.clip.text_encoder.dict(), ) - with tokenizer_info as orig_tokenizer,\ - text_encoder_info as text_encoder: - loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] + def _lora_loader(): + for lora in self.clip.loras: + lora_info = context.services.model_manager.get_model( + **lora.dict(exclude={"weight"})) + yield (lora_info.context.model, lora.weight) + del lora_info + return - ti_list = [] - for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): - name = trigger[1:-1] - try: - ti_list.append( - context.services.model_manager.get_model( - model_name=name, - base_model=self.clip.text_encoder.base_model, - model_type=ModelType.TextualInversion, - ).context.model - ) - except Exception: - #print(e) - #import traceback - #print(traceback.format_exc()) - print(f"Warn: trigger: \"{trigger}\" not found") + #loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] - with ModelPatcher.apply_lora_text_encoder(text_encoder, loras),\ - ModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager): - - compel = Compel( - tokenizer=tokenizer, - text_encoder=text_encoder, - textual_inversion_manager=ti_manager, - dtype_for_device_getter=torch_dtype, - truncate_long_prompts=True, # TODO: + ti_list = [] + for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): + name = trigger[1:-1] + try: + ti_list.append( + context.services.model_manager.get_model( + model_name=name, + base_model=self.clip.text_encoder.base_model, + model_type=ModelType.TextualInversion, + ).context.model ) - - conjunction = Compel.parse_prompt_string(self.prompt) - prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0] + except Exception: + # print(e) + #import traceback + # print(traceback.format_exc()) + print(f"Warn: trigger: \"{trigger}\" not found") - if context.services.configuration.log_tokenization: - log_tokenization_for_prompt_object(prompt, tokenizer) + with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\ + ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\ + text_encoder_info as text_encoder: - c, options = compel.build_conditioning_tensor_for_prompt_object(prompt) - - # TODO: long prompt support - #if not self.truncate_long_prompts: - # [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc]) - ec = InvokeAIDiffuserComponent.ExtraConditioningInfo( - tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction), - cross_attention_control_args=options.get("cross_attention_control", None), - ) - - conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - - # TODO: hacky but works ;D maybe rename latents somehow? - context.services.latents.save(conditioning_name, (c, ec)) - - return CompelOutput( - conditioning=ConditioningField( - conditioning_name=conditioning_name, - ), + compel = Compel( + tokenizer=tokenizer, + text_encoder=text_encoder, + textual_inversion_manager=ti_manager, + dtype_for_device_getter=torch_dtype, + truncate_long_prompts=True, # TODO: ) + conjunction = Compel.parse_prompt_string(self.prompt) + prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0] + + if context.services.configuration.log_tokenization: + log_tokenization_for_prompt_object(prompt, tokenizer) + + c, options = compel.build_conditioning_tensor_for_prompt_object( + prompt) + + # TODO: long prompt support + # if not self.truncate_long_prompts: + # [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc]) + ec = InvokeAIDiffuserComponent.ExtraConditioningInfo( + tokens_count_including_eos_bos=get_max_token_count( + tokenizer, conjunction), + cross_attention_control_args=options.get( + "cross_attention_control", None),) + + conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" + + # TODO: hacky but works ;D maybe rename latents somehow? + context.services.latents.save(conditioning_name, (c, ec)) + + return CompelOutput( + conditioning=ConditioningField( + conditioning_name=conditioning_name, + ), + ) + def get_max_token_count( - tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False -) -> int: + tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], + truncate_if_too_long=False) -> int: if type(prompt) is Blend: blend: Blend = prompt return max( @@ -147,13 +152,13 @@ def get_max_token_count( ) else: return len( - get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long) - ) + get_tokens_for_prompt_object( + tokenizer, prompt, truncate_if_too_long)) def get_tokens_for_prompt_object( tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True -) -> [str]: +) -> List[str]: if type(parsed_prompt) is Blend: raise ValueError( "Blend is not supported here - you need to get tokens for each of its .children" @@ -182,7 +187,7 @@ def log_tokenization_for_conjunction( ): display_label_prefix = display_label_prefix or "" for i, p in enumerate(c.prompts): - if len(c.prompts)>1: + if len(c.prompts) > 1: this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})" else: this_display_label_prefix = display_label_prefix @@ -237,7 +242,8 @@ def log_tokenization_for_prompt_object( ) -def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False): +def log_tokenization_for_text( + text, tokenizer, display_label=None, truncate_if_too_long=False): """shows how the prompt is tokenized # usually tokens have '' to indicate end-of-word, # but for readability it has been replaced with ' ' diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 36b6b58df2..5bdeaa5a9c 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -3,16 +3,16 @@ from typing import List, Literal, Optional, Union import einops - -from pydantic import BaseModel, Field, validator import torch from diffusers import ControlNetModel from diffusers.image_processor import VaeImageProcessor from diffusers.schedulers import SchedulerMixin as Scheduler +from pydantic import BaseModel, Field, validator from invokeai.app.util.step_callback import stable_diffusion_step_callback from ..models.image import ImageCategory, ImageField, ResourceOrigin +from ...backend.model_management.lora import ModelPatcher from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion.diffusers_pipeline import ( ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline, @@ -21,7 +21,6 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \ PostprocessingSettings from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.util.devices import torch_dtype -from ...backend.model_management.lora import ModelPatcher from .baseinvocation import (BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext) from .compel import ConditioningField @@ -29,14 +28,17 @@ from .controlnet_image_processors import ControlField from .image import ImageOutput from .model import ModelInfo, UNetField, VaeField + class LatentsField(BaseModel): """A latents field used for passing latents between invocations""" - latents_name: Optional[str] = Field(default=None, description="The name of the latents") + latents_name: Optional[str] = Field( + default=None, description="The name of the latents") class Config: schema_extra = {"required": ["latents_name"]} + class LatentsOutput(BaseInvocationOutput): """Base class for invocations that output latents""" #fmt: off @@ -50,11 +52,11 @@ class LatentsOutput(BaseInvocationOutput): def build_latents_output(latents_name: str, latents: torch.Tensor): - return LatentsOutput( - latents=LatentsField(latents_name=latents_name), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, - ) + return LatentsOutput( + latents=LatentsField(latents_name=latents_name), + width=latents.size()[3] * 8, + height=latents.size()[2] * 8, + ) SAMPLER_NAME_VALUES = Literal[ @@ -67,16 +69,19 @@ def get_scheduler( scheduler_info: ModelInfo, scheduler_name: str, ) -> Scheduler: - scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim']) - orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict()) + scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get( + scheduler_name, SCHEDULER_MAP['ddim']) + orig_scheduler_info = context.services.model_manager.get_model( + **scheduler_info.dict()) with orig_scheduler_info as orig_scheduler: scheduler_config = orig_scheduler.config - + if "_backup" in scheduler_config: scheduler_config = scheduler_config["_backup"] - scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config} + scheduler_config = {**scheduler_config, ** + scheduler_extra_config, "_backup": scheduler_config} scheduler = scheduler_class.from_config(scheduler_config) - + # hack copied over from generate.py if not hasattr(scheduler, 'uses_inpainting_model'): scheduler.uses_inpainting_model = lambda: False @@ -121,18 +126,18 @@ class TextToLatentsInvocation(BaseInvocation): "ui": { "tags": ["latents"], "type_hints": { - "model": "model", - "control": "control", - # "cfg_scale": "float", - "cfg_scale": "number" + "model": "model", + "control": "control", + # "cfg_scale": "float", + "cfg_scale": "number" } }, } # TODO: pass this an emitter method or something? or a session for dispatching? def dispatch_progress( - self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState - ) -> None: + self, context: InvocationContext, source_node_id: str, + intermediate_state: PipelineIntermediateState) -> None: stable_diffusion_step_callback( context=context, intermediate_state=intermediate_state, @@ -140,9 +145,12 @@ class TextToLatentsInvocation(BaseInvocation): source_node_id=source_node_id, ) - def get_conditioning_data(self, context: InvocationContext, scheduler) -> ConditioningData: - c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name) - uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name) + def get_conditioning_data( + self, context: InvocationContext, scheduler) -> ConditioningData: + c, extra_conditioning_info = context.services.latents.get( + self.positive_conditioning.conditioning_name) + uc, _ = context.services.latents.get( + self.negative_conditioning.conditioning_name) conditioning_data = ConditioningData( unconditioned_embeddings=uc, @@ -150,10 +158,10 @@ class TextToLatentsInvocation(BaseInvocation): guidance_scale=self.cfg_scale, extra=extra_conditioning_info, postprocessing_settings=PostprocessingSettings( - threshold=0.0,#threshold, - warmup=0.2,#warmup, - h_symmetry_time_pct=None,#h_symmetry_time_pct, - v_symmetry_time_pct=None#v_symmetry_time_pct, + threshold=0.0, # threshold, + warmup=0.2, # warmup, + h_symmetry_time_pct=None, # h_symmetry_time_pct, + v_symmetry_time_pct=None # v_symmetry_time_pct, ), ) @@ -161,31 +169,32 @@ class TextToLatentsInvocation(BaseInvocation): scheduler, # for ddim scheduler - eta=0.0, #ddim_eta + eta=0.0, # ddim_eta # for ancestral and sde schedulers generator=torch.Generator(device=uc.device).manual_seed(0), ) return conditioning_data - def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline: + def create_pipeline( + self, unet, scheduler) -> StableDiffusionGeneratorPipeline: # TODO: - #configure_model_padding( + # configure_model_padding( # unet, # self.seamless, # self.seamless_axes, - #) + # ) class FakeVae: class FakeVaeConfig: def __init__(self): self.block_out_channels = [0] - + def __init__(self): self.config = FakeVae.FakeVaeConfig() return StableDiffusionGeneratorPipeline( - vae=FakeVae(), # TODO: oh... + vae=FakeVae(), # TODO: oh... text_encoder=None, tokenizer=None, unet=unet, @@ -195,11 +204,12 @@ class TextToLatentsInvocation(BaseInvocation): requires_safety_checker=False, precision="float16" if unet.dtype == torch.float16 else "float32", ) - + def prep_control_data( self, context: InvocationContext, - model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device + # really only need model for dtype and device + model: StableDiffusionGeneratorPipeline, control_input: List[ControlField], latents_shape: List[int], do_classifier_free_guidance: bool = True, @@ -235,15 +245,17 @@ class TextToLatentsInvocation(BaseInvocation): print("Using HF model subfolders") print(" control_name: ", control_name) print(" control_subfolder: ", control_subfolder) - control_model = ControlNetModel.from_pretrained(control_name, - subfolder=control_subfolder, - torch_dtype=model.unet.dtype).to(model.device) + control_model = ControlNetModel.from_pretrained( + control_name, subfolder=control_subfolder, + torch_dtype=model.unet.dtype).to( + model.device) else: - control_model = ControlNetModel.from_pretrained(control_info.control_model, - torch_dtype=model.unet.dtype).to(model.device) + control_model = ControlNetModel.from_pretrained( + control_info.control_model, torch_dtype=model.unet.dtype).to(model.device) control_models.append(control_model) control_image_field = control_info.image - input_image = context.services.images.get_pil_image(control_image_field.image_name) + input_image = context.services.images.get_pil_image( + control_image_field.image_name) # self.image.image_type, self.image.image_name # FIXME: still need to test with different widths, heights, devices, dtypes # and add in batch_size, num_images_per_prompt? @@ -260,41 +272,50 @@ class TextToLatentsInvocation(BaseInvocation): dtype=control_model.dtype, control_mode=control_info.control_mode, ) - control_item = ControlNetData(model=control_model, - image_tensor=control_image, - weight=control_info.control_weight, - begin_step_percent=control_info.begin_step_percent, - end_step_percent=control_info.end_step_percent, - control_mode=control_info.control_mode, - ) + control_item = ControlNetData( + model=control_model, image_tensor=control_image, + weight=control_info.control_weight, + begin_step_percent=control_info.begin_step_percent, + end_step_percent=control_info.end_step_percent, + control_mode=control_info.control_mode,) control_data.append(control_item) # MultiControlNetModel has been refactored out, just need list[ControlNetData] return control_data + @torch.no_grad() def invoke(self, context: InvocationContext) -> LatentsOutput: noise = context.services.latents.get(self.noise.latents_name) # Get the source node id (we are invoking the prepared node) - graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) + graph_execution_state = context.services.graph_execution_manager.get( + context.graph_execution_state_id) source_node_id = graph_execution_state.prepared_source_mapping[self.id] def step_callback(state: PipelineIntermediateState): self.dispatch_progress(context, source_node_id, state) - unet_info = context.services.model_manager.get_model(**self.unet.unet.dict()) - with unet_info as unet: + def _lora_loader(): + for lora in self.unet.loras: + lora_info = context.services.model_manager.get_model( + **lora.dict(exclude={"weight"})) + yield (lora_info.context.model, lora.weight) + del lora_info + return + + unet_info = context.services.model_manager.get_model( + **self.unet.unet.dict()) + with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ + unet_info as unet: scheduler = get_scheduler( context=context, scheduler_info=self.unet.scheduler, scheduler_name=self.scheduler, ) - + pipeline = self.create_pipeline(unet, scheduler) conditioning_data = self.get_conditioning_data(context, scheduler) - loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras] - control_data = self.prep_control_data( model=pipeline, context=context, control_input=self.control, latents_shape=noise.shape, @@ -302,16 +323,15 @@ class TextToLatentsInvocation(BaseInvocation): do_classifier_free_guidance=True, ) - with ModelPatcher.apply_lora_unet(pipeline.unet, loras): - # TODO: Verify the noise is the right size - result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( - latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)), - noise=noise, - num_inference_steps=self.steps, - conditioning_data=conditioning_data, - control_data=control_data, # list[ControlNetData] - callback=step_callback, - ) + # TODO: Verify the noise is the right size + result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( + latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)), + noise=noise, + num_inference_steps=self.steps, + conditioning_data=conditioning_data, + control_data=control_data, # list[ControlNetData] + callback=step_callback, + ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() @@ -320,14 +340,18 @@ class TextToLatentsInvocation(BaseInvocation): context.services.latents.save(name, result_latents) return build_latents_output(latents_name=name, latents=result_latents) + class LatentsToLatentsInvocation(TextToLatentsInvocation): """Generates latents using latents as base image.""" type: Literal["l2l"] = "l2l" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to use as a base image") - strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use") + latents: Optional[LatentsField] = Field( + description="The latents to use as a base image") + strength: float = Field( + default=0.7, ge=0, le=1, + description="The strength of the latents to use") # Schema customisation class Config(InvocationConfig): @@ -342,22 +366,31 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): }, } + @torch.no_grad() def invoke(self, context: InvocationContext) -> LatentsOutput: noise = context.services.latents.get(self.noise.latents_name) latent = context.services.latents.get(self.latents.latents_name) # Get the source node id (we are invoking the prepared node) - graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) + graph_execution_state = context.services.graph_execution_manager.get( + context.graph_execution_state_id) source_node_id = graph_execution_state.prepared_source_mapping[self.id] def step_callback(state: PipelineIntermediateState): self.dispatch_progress(context, source_node_id, state) - unet_info = context.services.model_manager.get_model( - **self.unet.unet.dict(), - ) + def _lora_loader(): + for lora in self.unet.loras: + lora_info = context.services.model_manager.get_model( + **lora.dict(exclude={"weight"})) + yield (lora_info.context.model, lora.weight) + del lora_info + return - with unet_info as unet: + unet_info = context.services.model_manager.get_model( + **self.unet.unet.dict()) + with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ + unet_info as unet: scheduler = get_scheduler( context=context, @@ -367,7 +400,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): pipeline = self.create_pipeline(unet, scheduler) conditioning_data = self.get_conditioning_data(context, scheduler) - + control_data = self.prep_control_data( model=pipeline, context=context, control_input=self.control, latents_shape=noise.shape, @@ -377,8 +410,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): # TODO: Verify the noise is the right size initial_latents = latent if self.strength < 1.0 else torch.zeros_like( - latent, device=unet.device, dtype=latent.dtype - ) + latent, device=unet.device, dtype=latent.dtype) timesteps, _ = pipeline.get_img2img_timesteps( self.steps, @@ -386,18 +418,15 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): device=unet.device, ) - loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras] - - with ModelPatcher.apply_lora_unet(pipeline.unet, loras): - result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( - latents=initial_latents, - timesteps=timesteps, - noise=noise, - num_inference_steps=self.steps, - conditioning_data=conditioning_data, - control_data=control_data, # list[ControlNetData] - callback=step_callback - ) + result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( + latents=initial_latents, + timesteps=timesteps, + noise=noise, + num_inference_steps=self.steps, + conditioning_data=conditioning_data, + control_data=control_data, # list[ControlNetData] + callback=step_callback + ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() @@ -414,9 +443,12 @@ class LatentsToImageInvocation(BaseInvocation): type: Literal["l2i"] = "l2i" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to generate an image from") + latents: Optional[LatentsField] = Field( + description="The latents to generate an image from") vae: VaeField = Field(default=None, description="Vae submodel") - tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)") + tiled: bool = Field( + default=False, + description="Decode latents by overlaping tiles(less memory consumption)") # Schema customisation class Config(InvocationConfig): @@ -447,7 +479,7 @@ class LatentsToImageInvocation(BaseInvocation): # copied from diffusers pipeline latents = latents / vae.config.scaling_factor image = vae.decode(latents, return_dict=False)[0] - image = (image / 2 + 0.5).clamp(0, 1) # denormalize + image = (image / 2 + 0.5).clamp(0, 1) # denormalize # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 np_image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -470,9 +502,9 @@ class LatentsToImageInvocation(BaseInvocation): height=image_dto.height, ) -LATENTS_INTERPOLATION_MODE = Literal[ - "nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact" -] + +LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", + "bilinear", "bicubic", "trilinear", "area", "nearest-exact"] class ResizeLatentsInvocation(BaseInvocation): @@ -481,21 +513,25 @@ class ResizeLatentsInvocation(BaseInvocation): type: Literal["lresize"] = "lresize" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to resize") - width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)") - height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)") - mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode") - antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") + latents: Optional[LatentsField] = Field( + description="The latents to resize") + width: int = Field( + ge=64, multiple_of=8, description="The width to resize to (px)") + height: int = Field( + ge=64, multiple_of=8, description="The height to resize to (px)") + mode: LATENTS_INTERPOLATION_MODE = Field( + default="bilinear", description="The interpolation mode") + antialias: bool = Field( + default=False, + description="Whether or not to antialias (applied in bilinear and bicubic modes only)") def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.services.latents.get(self.latents.latents_name) resized_latents = torch.nn.functional.interpolate( - latents, - size=(self.height // 8, self.width // 8), - mode=self.mode, - antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False, - ) + latents, size=(self.height // 8, self.width // 8), + mode=self.mode, antialias=self.antialias + if self.mode in ["bilinear", "bicubic"] else False,) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() @@ -512,21 +548,24 @@ class ScaleLatentsInvocation(BaseInvocation): type: Literal["lscale"] = "lscale" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to scale") - scale_factor: float = Field(gt=0, description="The factor by which to scale the latents") - mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode") - antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") + latents: Optional[LatentsField] = Field( + description="The latents to scale") + scale_factor: float = Field( + gt=0, description="The factor by which to scale the latents") + mode: LATENTS_INTERPOLATION_MODE = Field( + default="bilinear", description="The interpolation mode") + antialias: bool = Field( + default=False, + description="Whether or not to antialias (applied in bilinear and bicubic modes only)") def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.services.latents.get(self.latents.latents_name) # resizing resized_latents = torch.nn.functional.interpolate( - latents, - scale_factor=self.scale_factor, - mode=self.mode, - antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False, - ) + latents, scale_factor=self.scale_factor, mode=self.mode, + antialias=self.antialias + if self.mode in ["bilinear", "bicubic"] else False,) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() @@ -545,7 +584,9 @@ class ImageToLatentsInvocation(BaseInvocation): # Inputs image: Optional[ImageField] = Field(description="The image to encode") vae: VaeField = Field(default=None, description="Vae submodel") - tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)") + tiled: bool = Field( + default=False, + description="Encode latents by overlaping tiles(less memory consumption)") # Schema customisation class Config(InvocationConfig): diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 760fa08a12..17297ba417 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -1,31 +1,38 @@ -from typing import Literal, Optional, Union, List -from pydantic import BaseModel, Field import copy +from typing import List, Literal, Optional, Union -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig +from pydantic import BaseModel, Field -from ...backend.util.devices import choose_torch_device, torch_dtype from ...backend.model_management import BaseModelType, ModelType, SubModelType +from .baseinvocation import (BaseInvocation, BaseInvocationOutput, + InvocationConfig, InvocationContext) + class ModelInfo(BaseModel): model_name: str = Field(description="Info to load submodel") base_model: BaseModelType = Field(description="Base model") model_type: ModelType = Field(description="Info to load submodel") - submodel: Optional[SubModelType] = Field(description="Info to load submodel") + submodel: Optional[SubModelType] = Field( + default=None, description="Info to load submodel" + ) + class LoraInfo(ModelInfo): weight: float = Field(description="Lora's weight which to use when apply to model") + class UNetField(BaseModel): unet: ModelInfo = Field(description="Info to load unet submodel") scheduler: ModelInfo = Field(description="Info to load scheduler submodel") loras: List[LoraInfo] = Field(description="Loras to apply on model loading") + class ClipField(BaseModel): tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel") text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel") loras: List[LoraInfo] = Field(description="Loras to apply on model loading") + class VaeField(BaseModel): # TODO: better naming? vae: ModelInfo = Field(description="Info to load vae submodel") @@ -34,43 +41,48 @@ class VaeField(BaseModel): class ModelLoaderOutput(BaseInvocationOutput): """Model loader output""" - #fmt: off + # fmt: off type: Literal["model_loader_output"] = "model_loader_output" unet: UNetField = Field(default=None, description="UNet submodel") clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") vae: VaeField = Field(default=None, description="Vae submodel") - #fmt: on + # fmt: on -class PipelineModelField(BaseModel): - """Pipeline model field""" +class MainModelField(BaseModel): + """Main model field""" model_name: str = Field(description="Name of the model") base_model: BaseModelType = Field(description="Base model") -class PipelineModelLoaderInvocation(BaseInvocation): - """Loads a pipeline model, outputting its submodels.""" +class LoRAModelField(BaseModel): + """LoRA model field""" - type: Literal["pipeline_model_loader"] = "pipeline_model_loader" + model_name: str = Field(description="Name of the LoRA model") + base_model: BaseModelType = Field(description="Base model") - model: PipelineModelField = Field(description="The model to load") + +class MainModelLoaderInvocation(BaseInvocation): + """Loads a main model, outputting its submodels.""" + + type: Literal["main_model_loader"] = "main_model_loader" + + model: MainModelField = Field(description="The model to load") # TODO: precision? # Schema customisation class Config(InvocationConfig): schema_extra = { "ui": { + "title": "Model Loader", "tags": ["model", "loader"], - "type_hints": { - "model": "model" - } + "type_hints": {"model": "model"}, }, } def invoke(self, context: InvocationContext) -> ModelLoaderOutput: - base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.Main @@ -112,7 +124,6 @@ class PipelineModelLoaderInvocation(BaseInvocation): ) """ - return ModelLoaderOutput( unet=UNetField( unet=ModelInfo( @@ -151,47 +162,66 @@ class PipelineModelLoaderInvocation(BaseInvocation): model_type=model_type, submodel=SubModelType.Vae, ), - ) + ), ) + class LoraLoaderOutput(BaseInvocationOutput): """Model loader output""" - #fmt: off + # fmt: off type: Literal["lora_loader_output"] = "lora_loader_output" unet: Optional[UNetField] = Field(default=None, description="UNet submodel") clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels") - #fmt: on + # fmt: on + class LoraLoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" type: Literal["lora_loader"] = "lora_loader" - lora_name: str = Field(description="Lora model name") + lora: Union[LoRAModelField, None] = Field( + default=None, description="Lora model name" + ) weight: float = Field(default=0.75, description="With what weight to apply lora") unet: Optional[UNetField] = Field(description="UNet model for applying lora") clip: Optional[ClipField] = Field(description="Clip model for applying lora") - def invoke(self, context: InvocationContext) -> LoraLoaderOutput: + class Config(InvocationConfig): + schema_extra = { + "ui": { + "title": "Lora Loader", + "tags": ["lora", "loader"], + "type_hints": {"lora": "lora_model"}, + }, + } - # TODO: ui rewrite - base_model = BaseModelType.StableDiffusion1 + def invoke(self, context: InvocationContext) -> LoraLoaderOutput: + if self.lora is None: + raise Exception("No LoRA provided") + + base_model = self.lora.base_model + lora_name = self.lora.model_name if not context.services.model_manager.model_exists( base_model=base_model, - model_name=self.lora_name, + model_name=lora_name, model_type=ModelType.Lora, ): - raise Exception(f"Unkown lora name: {self.lora_name}!") + raise Exception(f"Unkown lora name: {lora_name}!") - if self.unet is not None and any(lora.model_name == self.lora_name for lora in self.unet.loras): - raise Exception(f"Lora \"{self.lora_name}\" already applied to unet") + if self.unet is not None and any( + lora.model_name == lora_name for lora in self.unet.loras + ): + raise Exception(f'Lora "{lora_name}" already applied to unet') - if self.clip is not None and any(lora.model_name == self.lora_name for lora in self.clip.loras): - raise Exception(f"Lora \"{self.lora_name}\" already applied to clip") + if self.clip is not None and any( + lora.model_name == lora_name for lora in self.clip.loras + ): + raise Exception(f'Lora "{lora_name}" already applied to clip') output = LoraLoaderOutput() @@ -200,7 +230,7 @@ class LoraLoaderInvocation(BaseInvocation): output.unet.loras.append( LoraInfo( base_model=base_model, - model_name=self.lora_name, + model_name=lora_name, model_type=ModelType.Lora, submodel=None, weight=self.weight, @@ -212,7 +242,7 @@ class LoraLoaderInvocation(BaseInvocation): output.clip.loras.append( LoraInfo( base_model=base_model, - model_name=self.lora_name, + model_name=lora_name, model_type=ModelType.Lora, submodel=None, weight=self.weight, @@ -221,3 +251,58 @@ class LoraLoaderInvocation(BaseInvocation): return output + +class VAEModelField(BaseModel): + """Vae model field""" + + model_name: str = Field(description="Name of the model") + base_model: BaseModelType = Field(description="Base model") + + +class VaeLoaderOutput(BaseInvocationOutput): + """Model loader output""" + + # fmt: off + type: Literal["vae_loader_output"] = "vae_loader_output" + + vae: VaeField = Field(default=None, description="Vae model") + # fmt: on + + +class VaeLoaderInvocation(BaseInvocation): + """Loads a VAE model, outputting a VaeLoaderOutput""" + + type: Literal["vae_loader"] = "vae_loader" + + vae_model: VAEModelField = Field(description="The VAE to load") + + # Schema customisation + class Config(InvocationConfig): + schema_extra = { + "ui": { + "title": "VAE Loader", + "tags": ["vae", "loader"], + "type_hints": {"vae_model": "vae_model"}, + }, + } + + def invoke(self, context: InvocationContext) -> VaeLoaderOutput: + base_model = self.vae_model.base_model + model_name = self.vae_model.model_name + model_type = ModelType.Vae + + if not context.services.model_manager.model_exists( + base_model=base_model, + model_name=model_name, + model_type=model_type, + ): + raise Exception(f"Unkown vae name: {model_name}!") + return VaeLoaderOutput( + vae=VaeField( + vae=ModelInfo( + model_name=model_name, + base_model=base_model, + model_type=model_type, + ) + ) + ) diff --git a/invokeai/app/services/config.py b/invokeai/app/services/config.py index e0f1ceeb25..e7f817fc0a 100644 --- a/invokeai/app/services/config.py +++ b/invokeai/app/services/config.py @@ -228,10 +228,10 @@ class InvokeAISettings(BaseSettings): upcase_environ = dict() for key,value in os.environ.items(): upcase_environ[key.upper()] = value - + fields = cls.__fields__ cls.argparse_groups = {} - + for name, field in fields.items(): if name not in cls._excluded(): current_default = field.default @@ -348,7 +348,7 @@ setting environment variables INVOKEAI_. ''' singleton_config: ClassVar[InvokeAIAppConfig] = None singleton_init: ClassVar[Dict] = None - + #fmt: off type: Literal["InvokeAI"] = "InvokeAI" host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server') @@ -367,7 +367,8 @@ setting environment variables INVOKEAI_. always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance') free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance') - max_loaded_models : int = Field(default=3, gt=0, description="Maximum number of models to keep in memory for rapid switching", category='Memory/Performance') + max_loaded_models : int = Field(default=3, gt=0, description="(DEPRECATED: use max_cache_size) Maximum number of models to keep in memory for rapid switching", category='Memory/Performance') + max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance') precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance') sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance') xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance') @@ -385,9 +386,9 @@ setting environment variables INVOKEAI_. outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths') from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths') use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths') - + model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models') - + log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=", "syslog=path|address:host:port", "http="', category="Logging") # note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues log_format : Literal[tuple(['plain','color','syslog','legacy'])] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging") @@ -396,7 +397,7 @@ setting environment variables INVOKEAI_. def parse_args(self, argv: List[str]=None, conf: DictConfig = None, clobber=False): ''' - Update settings with contents of init file, environment, and + Update settings with contents of init file, environment, and command-line settings. :param conf: alternate Omegaconf dictionary object :param argv: aternate sys.argv list @@ -411,7 +412,7 @@ setting environment variables INVOKEAI_. except: pass InvokeAISettings.initconf = conf - + # parse args again in order to pick up settings in configuration file super().parse_args(argv) @@ -431,7 +432,7 @@ setting environment variables INVOKEAI_. cls.singleton_config = cls(**kwargs) cls.singleton_init = kwargs return cls.singleton_config - + @property def root_path(self)->Path: ''' diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 10d1d91920..4e1da3b040 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -7,7 +7,7 @@ if TYPE_CHECKING: from invokeai.app.services.board_images import BoardImagesServiceABC from invokeai.app.services.boards import BoardServiceABC from invokeai.app.services.images import ImageServiceABC - from invokeai.backend import ModelManager + from invokeai.app.services.model_manager_service import ModelManagerServiceBase from invokeai.app.services.events import EventServiceBase from invokeai.app.services.latent_storage import LatentsStorageBase from invokeai.app.services.restoration_services import RestorationServices @@ -22,46 +22,47 @@ class InvocationServices: """Services that can be used by invocations""" # TODO: Just forward-declared everything due to circular dependencies. Fix structure. - events: "EventServiceBase" - latents: "LatentsStorageBase" - queue: "InvocationQueueABC" - model_manager: "ModelManager" - restoration: "RestorationServices" - configuration: "InvokeAISettings" - images: "ImageServiceABC" - boards: "BoardServiceABC" board_images: "BoardImagesServiceABC" - graph_library: "ItemStorageABC"["LibraryGraph"] + boards: "BoardServiceABC" + configuration: "InvokeAISettings" + events: "EventServiceBase" graph_execution_manager: "ItemStorageABC"["GraphExecutionState"] + graph_library: "ItemStorageABC"["LibraryGraph"] + images: "ImageServiceABC" + latents: "LatentsStorageBase" + logger: "Logger" + model_manager: "ModelManagerServiceBase" processor: "InvocationProcessorABC" + queue: "InvocationQueueABC" + restoration: "RestorationServices" def __init__( self, - model_manager: "ModelManager", - events: "EventServiceBase", - logger: "Logger", - latents: "LatentsStorageBase", - images: "ImageServiceABC", - boards: "BoardServiceABC", board_images: "BoardImagesServiceABC", - queue: "InvocationQueueABC", - graph_library: "ItemStorageABC"["LibraryGraph"], - graph_execution_manager: "ItemStorageABC"["GraphExecutionState"], - processor: "InvocationProcessorABC", - restoration: "RestorationServices", + boards: "BoardServiceABC", configuration: "InvokeAISettings", + events: "EventServiceBase", + graph_execution_manager: "ItemStorageABC"["GraphExecutionState"], + graph_library: "ItemStorageABC"["LibraryGraph"], + images: "ImageServiceABC", + latents: "LatentsStorageBase", + logger: "Logger", + model_manager: "ModelManagerServiceBase", + processor: "InvocationProcessorABC", + queue: "InvocationQueueABC", + restoration: "RestorationServices", ): - self.model_manager = model_manager - self.events = events - self.logger = logger - self.latents = latents - self.images = images - self.boards = boards self.board_images = board_images - self.queue = queue - self.graph_library = graph_library - self.graph_execution_manager = graph_execution_manager - self.processor = processor - self.restoration = restoration - self.configuration = configuration self.boards = boards + self.boards = boards + self.configuration = configuration + self.events = events + self.graph_execution_manager = graph_execution_manager + self.graph_library = graph_library + self.images = images + self.latents = latents + self.logger = logger + self.model_manager = model_manager + self.processor = processor + self.queue = queue + self.restoration = restoration diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 9e5dbc1a40..b25136c240 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -33,13 +33,13 @@ class ModelManagerServiceBase(ABC): logger: types.ModuleType, ): """ - Initialize with the path to the models.yaml config file. + Initialize with the path to the models.yaml config file. Optional parameters are the torch device type, precision, max_models, and sequential_offload boolean. Note that the default device type and precision are set up for a CUDA system running at half precision. """ pass - + @abstractmethod def get_model( self, @@ -50,8 +50,8 @@ class ModelManagerServiceBase(ABC): node: Optional[BaseInvocation] = None, context: Optional[InvocationContext] = None, ) -> ModelInfo: - """Retrieve the indicated model with name and type. - submodel can be used to get a part (such as the vae) + """Retrieve the indicated model with name and type. + submodel can be used to get a part (such as the vae) of a diffusers pipeline.""" pass @@ -115,8 +115,8 @@ class ModelManagerServiceBase(ABC): """ Update the named model with a dictionary of attributes. Will fail with an assertion error if the name already exists. Pass clobber=True to overwrite. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or + On a successful update, the config will be changed in memory. Will fail + with an assertion error if provided attributes are incorrect or the model name is missing. Call commit() to write changes to disk. """ pass @@ -129,12 +129,35 @@ class ModelManagerServiceBase(ABC): model_type: ModelType, ): """ - Delete the named model from configuration. If delete_files is true, - then the underlying weight file or diffusers directory will be deleted + Delete the named model from configuration. If delete_files is true, + then the underlying weight file or diffusers directory will be deleted as well. Call commit() to write to disk. """ pass + @abstractmethod + def heuristic_import(self, + items_to_import: Set[str], + prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, + )->Dict[str, AddModelResult]: + '''Import a list of paths, repo_ids or URLs. Returns the set of + successfully imported items. + :param items_to_import: Set of strings corresponding to models to be imported. + :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType. + + The prediction type helper is necessary to distinguish between + models based on Stable Diffusion 2 Base (requiring + SchedulerPredictionType.Epsilson) and Stable Diffusion 768 + (requiring SchedulerPredictionType.VPrediction). It is + generally impossible to do this programmatically, so the + prediction_type_helper usually asks the user to choose. + + The result is a set of successfully installed models. Each element + of the set is a dict corresponding to the newly-created OmegaConf stanza for + that model. + ''' + pass + @abstractmethod def commit(self, conf_file: Path = None) -> None: """ @@ -153,7 +176,7 @@ class ModelManagerService(ModelManagerServiceBase): logger: types.ModuleType, ): """ - Initialize with the path to the models.yaml config file. + Initialize with the path to the models.yaml config file. Optional parameters are the torch device type, precision, max_models, and sequential_offload boolean. Note that the default device type and precision are set up for a CUDA system running at half precision. @@ -185,6 +208,8 @@ class ModelManagerService(ModelManagerServiceBase): if hasattr(config,'max_cache_size') \ else config.max_loaded_models * 2.5 + logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB") + sequential_offload = config.sequential_guidance self.mgr = ModelManager( @@ -240,7 +265,7 @@ class ModelManagerService(ModelManagerServiceBase): submodel=submodel, model_info=model_info ) - + return model_info def model_exists( @@ -293,8 +318,8 @@ class ModelManagerService(ModelManagerServiceBase): """ Update the named model with a dictionary of attributes. Will fail with an assertion error if the name already exists. Pass clobber=True to overwrite. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or + On a successful update, the config will be changed in memory. Will fail + with an assertion error if provided attributes are incorrect or the model name is missing. Call commit() to write changes to disk. """ return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber) @@ -307,8 +332,8 @@ class ModelManagerService(ModelManagerServiceBase): model_type: ModelType, ): """ - Delete the named model from configuration. If delete_files is true, - then the underlying weight file or diffusers directory will be deleted + Delete the named model from configuration. If delete_files is true, + then the underlying weight file or diffusers directory will be deleted as well. Call commit() to write to disk. """ self.mgr.del_model(model_name, base_model, model_type) @@ -362,4 +387,25 @@ class ModelManagerService(ModelManagerServiceBase): @property def logger(self): return self.mgr.logger - + + def heuristic_import(self, + items_to_import: Set[str], + prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, + )->Dict[str, AddModelResult]: + '''Import a list of paths, repo_ids or URLs. Returns the set of + successfully imported items. + :param items_to_import: Set of strings corresponding to models to be imported. + :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType. + + The prediction type helper is necessary to distinguish between + models based on Stable Diffusion 2 Base (requiring + SchedulerPredictionType.Epsilson) and Stable Diffusion 768 + (requiring SchedulerPredictionType.VPrediction). It is + generally impossible to do this programmatically, so the + prediction_type_helper usually asks the user to choose. + + The result is a set of successfully installed models. Each element + of the set is a dict corresponding to the newly-created OmegaConf stanza for + that model. + ''' + return self.mgr.heuristic_import(items_to_import, prediction_type_helper) diff --git a/invokeai/backend/install/migrate_to_3.py b/invokeai/backend/install/migrate_to_3.py index fb3d964c7b..b32890f6b7 100644 --- a/invokeai/backend/install/migrate_to_3.py +++ b/invokeai/backend/install/migrate_to_3.py @@ -223,11 +223,11 @@ class MigrateTo3(object): repo_id = 'openai/clip-vit-large-patch14' self._migrate_pretrained(CLIPTokenizer, repo_id= repo_id, - dest= target_dir / 'clip-vit-large-patch14' / 'tokenizer', + dest= target_dir / 'clip-vit-large-patch14', **kwargs) self._migrate_pretrained(CLIPTextModel, repo_id = repo_id, - dest = target_dir / 'clip-vit-large-patch14' / 'text_encoder', + dest = target_dir / 'clip-vit-large-patch14', **kwargs) # sd-2 diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py index ac06c402d4..00c19712a0 100644 --- a/invokeai/backend/install/model_install_backend.py +++ b/invokeai/backend/install/model_install_backend.py @@ -19,7 +19,7 @@ from tqdm import tqdm import invokeai.configs as configs from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType +from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo from invokeai.backend.util import download_with_resume from ..util.logging import InvokeAILogger @@ -173,18 +173,23 @@ class ModelInstall(object): # add requested models for path in selections.install_models: logger.info(f'Installing {path} [{job}/{jobs}]') - self.heuristic_install(path) + self.heuristic_import(path) job += 1 dlogging.set_verbosity(verbosity) self.mgr.commit() - def heuristic_install(self, + def heuristic_import(self, model_path_id_or_url: Union[str,Path], - models_installed: Set[Path]=None)->Set[Path]: + models_installed: Set[Path]=None)->Dict[str, AddModelResult]: + ''' + :param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL + :param models_installed: Set of installed models, used for recursive invocation + Returns a set of dict objects corresponding to newly-created stanzas in models.yaml. + ''' if not models_installed: - models_installed = set() + models_installed = dict() # A little hack to allow nested routines to retrieve info on the requested ID self.current_id = model_path_id_or_url @@ -193,24 +198,24 @@ class ModelInstall(object): try: # checkpoint file, or similar if path.is_file(): - models_installed.add(self._install_path(path)) + models_installed.update(self._install_path(path)) # folders style or similar elif path.is_dir() and any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]): - models_installed.add(self._install_path(path)) + models_installed.update(self._install_path(path)) # recursive scan elif path.is_dir(): for child in path.iterdir(): - self.heuristic_install(child, models_installed=models_installed) + self.heuristic_import(child, models_installed=models_installed) # huggingface repo elif len(str(model_path_id_or_url).split('/')) == 2: - models_installed.add(self._install_repo(str(model_path_id_or_url))) + models_installed.update(self._install_repo(str(model_path_id_or_url))) # a URL elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")): - models_installed.add(self._install_url(model_path_id_or_url)) + models_installed.update(self._install_url(model_path_id_or_url)) else: logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping') @@ -222,24 +227,25 @@ class ModelInstall(object): # install a model from a local path. The optional info parameter is there to prevent # the model from being probed twice in the event that it has already been probed. - def _install_path(self, path: Path, info: ModelProbeInfo=None)->Path: + def _install_path(self, path: Path, info: ModelProbeInfo=None)->Dict[str, AddModelResult]: try: - # logger.debug(f'Probing {path}') + model_result = None info = info or ModelProbe().heuristic_probe(path,self.prediction_helper) - model_name = path.stem if info.format=='checkpoint' else path.name + model_name = path.stem if path.is_file() else path.name if self.mgr.model_exists(model_name, info.base_type, info.model_type): raise ValueError(f'A model named "{model_name}" is already installed.') attributes = self._make_attributes(path,info) - self.mgr.add_model(model_name = model_name, - base_model = info.base_type, - model_type = info.model_type, - model_attributes = attributes, - ) + model_result = self.mgr.add_model(model_name = model_name, + base_model = info.base_type, + model_type = info.model_type, + model_attributes = attributes, + ) except Exception as e: logger.warning(f'{str(e)} Skipping registration.') - return path + return {} + return {str(path): model_result} - def _install_url(self, url: str)->Path: + def _install_url(self, url: str)->dict: # copy to a staging area, probe, import and delete with TemporaryDirectory(dir=self.config.models_path) as staging: location = download_with_resume(url,Path(staging)) @@ -252,7 +258,7 @@ class ModelInstall(object): # staged version will be garbage-collected at this time return self._install_path(Path(models_path), info) - def _install_repo(self, repo_id: str)->Path: + def _install_repo(self, repo_id: str)->dict: hinfo = HfApi().model_info(repo_id) # we try to figure out how to download this most economically diff --git a/invokeai/backend/model_management/__init__.py b/invokeai/backend/model_management/__init__.py index fb3b20a20a..34e0b15728 100644 --- a/invokeai/backend/model_management/__init__.py +++ b/invokeai/backend/model_management/__init__.py @@ -1,7 +1,7 @@ """ Initialization file for invokeai.backend.model_management """ -from .model_manager import ModelManager, ModelInfo +from .model_manager import ModelManager, ModelInfo, AddModelResult from .model_cache import ModelCache from .models import BaseModelType, ModelType, SubModelType, ModelVariantType diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index 6cfcb8dd8d..5d27555ab3 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -1,18 +1,17 @@ from __future__ import annotations import copy -from pathlib import Path from contextlib import contextmanager -from typing import Optional, Dict, Tuple, Any +from pathlib import Path +from typing import Any, Dict, Optional, Tuple import torch +from compel.embeddings_provider import BaseTextualInversionManager +from diffusers.models import UNet2DConditionModel from safetensors.torch import load_file from torch.utils.hooks import RemovableHandle - -from diffusers.models import UNet2DConditionModel from transformers import CLIPTextModel -from compel.embeddings_provider import BaseTextualInversionManager class LoRALayerBase: #rank: Optional[int] @@ -539,9 +538,10 @@ class ModelPatcher: original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True) # enable autocast to calc fp16 loras on cpu - with torch.autocast(device_type="cpu"): - layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 - layer_weight = layer.get_weight() * lora_weight * layer_scale + #with torch.autocast(device_type="cpu"): + layer.to(dtype=torch.float32) + layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 + layer_weight = layer.get_weight() * lora_weight * layer_scale if module.weight.shape != layer_weight.shape: # TODO: debug on lycoris diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 30254aa060..82f738894d 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -233,14 +233,14 @@ import hashlib import textwrap from dataclasses import dataclass from pathlib import Path -from typing import Optional, List, Tuple, Union, Set, Callable, types +from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types from shutil import rmtree import torch from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig -from pydantic import BaseModel +from pydantic import BaseModel, Field import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig @@ -278,8 +278,13 @@ class InvalidModelError(Exception): "Raised when an invalid model is requested" pass -MAX_CACHE_SIZE = 6.0 # GB +class AddModelResult(BaseModel): + name: str = Field(description="The name of the model after import") + model_type: ModelType = Field(description="The type of model") + base_model: BaseModelType = Field(description="The base model") + config: ModelConfigBase = Field(description="The configuration of the model") +MAX_CACHE_SIZE = 6.0 # GB class ConfigMeta(BaseModel): version: str @@ -570,13 +575,16 @@ class ModelManager(object): model_type: ModelType, model_attributes: dict, clobber: bool = False, - ) -> None: + ) -> AddModelResult: """ Update the named model with a dictionary of attributes. Will fail with an assertion error if the name already exists. Pass clobber=True to overwrite. On a successful update, the config will be changed in memory and the method will return True. Will fail with an assertion error if provided attributes are incorrect or the model name is missing. + + The returned dict has the same format as the dict returned by + model_info(). """ model_class = MODEL_CLASSES[base_model][model_type] @@ -600,12 +608,18 @@ class ModelManager(object): old_model_cache.unlink() # remove in-memory cache - # note: it not garantie to release memory(model can has other references) + # note: it not guaranteed to release memory(model can has other references) cache_ids = self.cache_keys.pop(model_key, []) for cache_id in cache_ids: self.cache.uncache_model(cache_id) self.models[model_key] = model_config + return AddModelResult( + name = model_name, + model_type = model_type, + base_model = base_model, + config = model_config, + ) def search_models(self, search_folder): self.logger.info(f"Finding Models In: {search_folder}") @@ -717,19 +731,19 @@ class ModelManager(object): if model_path.is_relative_to(self.app_config.root_path): model_path = model_path.relative_to(self.app_config.root_path) - try: - model_config: ModelConfigBase = model_class.probe_config(str(model_path)) - self.models[model_key] = model_config - new_models_found = True - except NotImplementedError as e: - self.logger.warning(e) + try: + model_config: ModelConfigBase = model_class.probe_config(str(model_path)) + self.models[model_key] = model_config + new_models_found = True + except NotImplementedError as e: + self.logger.warning(e) imported_models = self.autoimport() if (new_models_found or imported_models) and self.config_path: self.commit() - def autoimport(self)->set[Path]: + def autoimport(self)->Dict[str, AddModelResult]: ''' Scan the autoimport directory (if defined) and import new models, delete defunct models. ''' @@ -742,7 +756,6 @@ class ModelManager(object): prediction_type_helper = ask_user_for_prediction_type, ) - installed = set() scanned_dirs = set() config = self.app_config @@ -756,13 +769,14 @@ class ModelManager(object): continue self.logger.info(f'Scanning {autodir} for models to import') + installed = dict() autodir = self.app_config.root_path / autodir if not autodir.exists(): continue items_scanned = 0 - new_models_found = set() + new_models_found = dict() for root, dirs, files in os.walk(autodir): items_scanned += len(dirs) + len(files) @@ -772,7 +786,7 @@ class ModelManager(object): scanned_dirs.add(path) continue if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]): - new_models_found.update(installer.heuristic_install(path)) + new_models_found.update(installer.heuristic_import(path)) scanned_dirs.add(path) for f in files: @@ -780,7 +794,7 @@ class ModelManager(object): if path in known_paths or path.parent in scanned_dirs: continue if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}: - new_models_found.update(installer.heuristic_install(path)) + new_models_found.update(installer.heuristic_import(path)) self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models') installed.update(new_models_found) @@ -790,7 +804,7 @@ class ModelManager(object): def heuristic_import(self, items_to_import: Set[str], prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, - )->Set[str]: + )->Dict[str, AddModelResult]: '''Import a list of paths, repo_ids or URLs. Returns the set of successfully imported items. :param items_to_import: Set of strings corresponding to models to be imported. @@ -803,17 +817,20 @@ class ModelManager(object): generally impossible to do this programmatically, so the prediction_type_helper usually asks the user to choose. + The result is a set of successfully installed models. Each element + of the set is a dict corresponding to the newly-created OmegaConf stanza for + that model. ''' # avoid circular import here from invokeai.backend.install.model_install_backend import ModelInstall - successfully_installed = set() + successfully_installed = dict() installer = ModelInstall(config = self.app_config, prediction_type_helper = prediction_type_helper, model_manager = self) for thing in items_to_import: try: - installed = installer.heuristic_install(thing) + installed = installer.heuristic_import(thing) successfully_installed.update(installed) except Exception as e: self.logger.warning(f'{thing} could not be imported: {str(e)}') diff --git a/invokeai/frontend/web/dist/index.html b/invokeai/frontend/web/dist/index.html index ed95c0c639..a0adc1d803 100644 --- a/invokeai/frontend/web/dist/index.html +++ b/invokeai/frontend/web/dist/index.html @@ -12,7 +12,7 @@ margin: 0; } - + diff --git a/invokeai/frontend/web/dist/locales/en.json b/invokeai/frontend/web/dist/locales/en.json index 1b3b790222..6fb56a2979 100644 --- a/invokeai/frontend/web/dist/locales/en.json +++ b/invokeai/frontend/web/dist/locales/en.json @@ -52,6 +52,7 @@ "unifiedCanvas": "Unified Canvas", "linear": "Linear", "nodes": "Node Editor", + "modelmanager": "Model Manager", "postprocessing": "Post Processing", "nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.", "postProcessing": "Post Processing", @@ -333,6 +334,7 @@ "modelManager": { "modelManager": "Model Manager", "model": "Model", + "vae": "VAE", "allModels": "All Models", "checkpointModels": "Checkpoints", "diffusersModels": "Diffusers", @@ -348,6 +350,7 @@ "scanForModels": "Scan For Models", "addManually": "Add Manually", "manual": "Manual", + "baseModel": "Base Model", "name": "Name", "nameValidationMsg": "Enter a name for your model", "description": "Description", @@ -360,6 +363,7 @@ "repoIDValidationMsg": "Online repository of your model", "vaeLocation": "VAE Location", "vaeLocationValidationMsg": "Path to where your VAE is located.", + "variant": "Variant", "vaeRepoID": "VAE Repo ID", "vaeRepoIDValidationMsg": "Online repository of your VAE", "width": "Width", diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json index d89f141e33..da5e15cb08 100644 --- a/invokeai/frontend/web/package.json +++ b/invokeai/frontend/web/package.json @@ -68,6 +68,7 @@ "@fontsource-variable/inter": "^5.0.3", "@fontsource/inter": "^5.0.3", "@mantine/core": "^6.0.14", + "@mantine/form": "^6.0.15", "@mantine/hooks": "^6.0.14", "@reduxjs/toolkit": "^1.9.5", "@roarr/browser-log-writer": "^1.1.5", diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index ab5d536f0c..9cf1e0bc48 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -53,6 +53,7 @@ "linear": "Linear", "nodes": "Node Editor", "batch": "Batch Manager", + "modelmanager": "Model Manager", "postprocessing": "Post Processing", "nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.", "postProcessing": "Post Processing", @@ -334,6 +335,7 @@ "modelManager": { "modelManager": "Model Manager", "model": "Model", + "vae": "VAE", "allModels": "All Models", "checkpointModels": "Checkpoints", "diffusersModels": "Diffusers", @@ -349,6 +351,7 @@ "scanForModels": "Scan For Models", "addManually": "Add Manually", "manual": "Manual", + "baseModel": "Base Model", "name": "Name", "nameValidationMsg": "Enter a name for your model", "description": "Description", @@ -361,6 +364,7 @@ "repoIDValidationMsg": "Online repository of your model", "vaeLocation": "VAE Location", "vaeLocationValidationMsg": "Path to where your VAE is located.", + "variant": "Variant", "vaeRepoID": "VAE Repo ID", "vaeRepoIDValidationMsg": "Online repository of your VAE", "width": "Width", diff --git a/invokeai/frontend/web/src/app/components/App.tsx b/invokeai/frontend/web/src/app/components/App.tsx index 2b0e247d48..f43c8fc5c0 100644 --- a/invokeai/frontend/web/src/app/components/App.tsx +++ b/invokeai/frontend/web/src/app/components/App.tsx @@ -4,6 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { PartialAppConfig } from 'app/types/invokeai'; import ImageUploader from 'common/components/ImageUploader'; import GalleryDrawer from 'features/gallery/components/GalleryPanel'; +import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal'; import Lightbox from 'features/lightbox/components/Lightbox'; import SiteHeader from 'features/system/components/SiteHeader'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; @@ -15,11 +16,10 @@ import InvokeTabs from 'features/ui/components/InvokeTabs'; import ParametersDrawer from 'features/ui/components/ParametersDrawer'; import i18n from 'i18n'; import { ReactNode, memo, useEffect } from 'react'; +import DeleteBoardImagesModal from '../../features/gallery/components/Boards/DeleteBoardImagesModal'; +import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal'; import GlobalHotkeys from './GlobalHotkeys'; import Toaster from './Toaster'; -import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal'; -import DeleteBoardImagesModal from '../../features/gallery/components/Boards/DeleteBoardImagesModal'; -import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal'; const DEFAULT_CONFIG = {}; diff --git a/invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx b/invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx index 5b6142d748..bf66c0ee08 100644 --- a/invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx +++ b/invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx @@ -1,4 +1,8 @@ import { Box, ChakraProps, Flex, Heading, Image } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { memo } from 'react'; import { TypesafeDraggableData } from './typesafeDnd'; @@ -28,7 +32,24 @@ const STYLES: ChakraProps['sx'] = { }, }; +const selector = createSelector( + stateSelector, + (state) => { + const gallerySelectionCount = state.gallery.selection.length; + const batchSelectionCount = state.batch.selection.length; + + return { + gallerySelectionCount, + batchSelectionCount, + }; + }, + defaultSelectorOptions +); + const DragPreview = (props: OverlayDragImageProps) => { + const { gallerySelectionCount, batchSelectionCount } = + useAppSelector(selector); + if (!props.dragData) { return; } @@ -57,7 +78,7 @@ const DragPreview = (props: OverlayDragImageProps) => { ); } - if (props.dragData.payloadType === 'IMAGE_NAMES') { + if (props.dragData.payloadType === 'BATCH_SELECTION') { return ( { ...STYLES, }} > - {props.dragData.payload.imageNames.length} + {batchSelectionCount} + Images + + ); + } + + if (props.dragData.payloadType === 'GALLERY_SELECTION') { + return ( + + {gallerySelectionCount} Images ); diff --git a/invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx b/invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx index e744a70750..1478ace748 100644 --- a/invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx +++ b/invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx @@ -77,14 +77,18 @@ export type ImageDraggableData = BaseDragData & { payload: { imageDTO: ImageDTO }; }; -export type ImageNamesDraggableData = BaseDragData & { - payloadType: 'IMAGE_NAMES'; - payload: { imageNames: string[] }; +export type GallerySelectionDraggableData = BaseDragData & { + payloadType: 'GALLERY_SELECTION'; +}; + +export type BatchSelectionDraggableData = BaseDragData & { + payloadType: 'BATCH_SELECTION'; }; export type TypesafeDraggableData = | ImageDraggableData - | ImageNamesDraggableData; + | GallerySelectionDraggableData + | BatchSelectionDraggableData; interface UseDroppableTypesafeArguments extends Omit { @@ -155,11 +159,13 @@ export const isValidDrop = ( case 'SET_NODES_IMAGE': return payloadType === 'IMAGE_DTO'; case 'SET_MULTI_NODES_IMAGE': - return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES'; + return payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION'; case 'ADD_TO_BATCH': - return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES'; + return payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION'; case 'MOVE_BOARD': - return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES'; + return ( + payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION' || 'BATCH_SELECTION' + ); default: return false; } diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts index cb18d48301..ac1b9c5205 100644 --- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts +++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts @@ -20,10 +20,8 @@ const serializationDenylist: { nodes: nodesPersistDenylist, postprocessing: postprocessingPersistDenylist, system: systemPersistDenylist, - // config: configPersistDenyList, ui: uiPersistDenylist, controlNet: controlNetDenylist, - // hotkeys: hotkeysPersistDenylist, }; export const serialize: SerializeFunction = (data, key) => { diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts index ca20170c5d..f083a716a4 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts @@ -1,21 +1,21 @@ -import { startAppListening } from '..'; -import { imageDeleted } from 'services/api/thunks/image'; import { log } from 'app/logging/useLogger'; -import { clamp } from 'lodash-es'; -import { - imageSelected, - imageRemoved, - selectImagesIds, -} from 'features/gallery/store/gallerySlice'; import { resetCanvas } from 'features/canvas/store/canvasSlice'; import { controlNetReset } from 'features/controlNet/store/controlNetSlice'; -import { clearInitialImage } from 'features/parameters/store/generationSlice'; -import { nodeEditorReset } from 'features/nodes/store/nodesSlice'; -import { api } from 'services/api'; +import { + imageRemoved, + imageSelected, + selectFilteredImages, +} from 'features/gallery/store/gallerySlice'; import { imageDeletionConfirmed, isModalOpenChanged, } from 'features/imageDeletion/store/imageDeletionSlice'; +import { nodeEditorReset } from 'features/nodes/store/nodesSlice'; +import { clearInitialImage } from 'features/parameters/store/generationSlice'; +import { clamp } from 'lodash-es'; +import { api } from 'services/api'; +import { imageDeleted } from 'services/api/thunks/image'; +import { startAppListening } from '..'; const moduleLog = log.child({ namespace: 'image' }); @@ -37,7 +37,9 @@ export const addRequestedImageDeletionListener = () => { state.gallery.selection[state.gallery.selection.length - 1]; if (lastSelectedImage === image_name) { - const ids = selectImagesIds(state); + const filteredImages = selectFilteredImages(state); + + const ids = filteredImages.map((i) => i.image_name); const deletedImageIndex = ids.findIndex( (result) => result.toString() === image_name diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts index 56f660a653..24a5bffec7 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts @@ -1,24 +1,23 @@ import { createAction } from '@reduxjs/toolkit'; -import { startAppListening } from '../'; -import { log } from 'app/logging/useLogger'; import { TypesafeDraggableData, TypesafeDroppableData, } from 'app/components/ImageDnd/typesafeDnd'; -import { imageSelected } from 'features/gallery/store/gallerySlice'; -import { initialImageChanged } from 'features/parameters/store/generationSlice'; +import { log } from 'app/logging/useLogger'; import { imageAddedToBatch, imagesAddedToBatch, } from 'features/batch/store/batchSlice'; -import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; +import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice'; +import { imageSelected } from 'features/gallery/store/gallerySlice'; import { fieldValueChanged, imageCollectionFieldValueChanged, } from 'features/nodes/store/nodesSlice'; -import { boardsApi } from 'services/api/endpoints/boards'; +import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { boardImagesApi } from 'services/api/endpoints/boardImages'; +import { startAppListening } from '../'; const moduleLog = log.child({ namespace: 'dnd' }); @@ -33,6 +32,7 @@ export const addImageDroppedListener = () => { effect: (action, { dispatch, getState }) => { const { activeData, overData } = action.payload; const { actionType } = overData; + const state = getState(); // set current image if ( @@ -64,9 +64,9 @@ export const addImageDroppedListener = () => { // add multiple images to batch if ( actionType === 'ADD_TO_BATCH' && - activeData.payloadType === 'IMAGE_NAMES' + activeData.payloadType === 'GALLERY_SELECTION' ) { - dispatch(imagesAddedToBatch(activeData.payload.imageNames)); + dispatch(imagesAddedToBatch(state.gallery.selection)); } // set control image @@ -128,14 +128,14 @@ export const addImageDroppedListener = () => { // set multiple nodes images (multiple images handler) if ( actionType === 'SET_MULTI_NODES_IMAGE' && - activeData.payloadType === 'IMAGE_NAMES' + activeData.payloadType === 'GALLERY_SELECTION' ) { const { fieldName, nodeId } = overData.context; dispatch( imageCollectionFieldValueChanged({ nodeId, fieldName, - value: activeData.payload.imageNames.map((image_name) => ({ + value: state.gallery.selection.map((image_name) => ({ image_name, })), }) diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index 2fd071bd23..5208933e7b 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -8,31 +8,32 @@ import { import dynamicMiddlewares from 'redux-dynamic-middlewares'; import { rememberEnhancer, rememberReducer } from 'redux-remember'; +import batchReducer from 'features/batch/store/batchSlice'; import canvasReducer from 'features/canvas/store/canvasSlice'; import controlNetReducer from 'features/controlNet/store/controlNetSlice'; +import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice'; +import boardsReducer from 'features/gallery/store/boardSlice'; import galleryReducer from 'features/gallery/store/gallerySlice'; +import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice'; import lightboxReducer from 'features/lightbox/store/lightboxSlice'; +import loraReducer from 'features/lora/store/loraSlice'; +import nodesReducer from 'features/nodes/store/nodesSlice'; import generationReducer from 'features/parameters/store/generationSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; -import systemReducer from 'features/system/store/systemSlice'; -import nodesReducer from 'features/nodes/store/nodesSlice'; -import boardsReducer from 'features/gallery/store/boardSlice'; import configReducer from 'features/system/store/configSlice'; +import systemReducer from 'features/system/store/systemSlice'; import hotkeysReducer from 'features/ui/store/hotkeysSlice'; import uiReducer from 'features/ui/store/uiSlice'; -import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice'; -import batchReducer from 'features/batch/store/batchSlice'; -import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice'; import { listenerMiddleware } from './middleware/listenerMiddleware'; -import { actionSanitizer } from './middleware/devtools/actionSanitizer'; -import { actionsDenylist } from './middleware/devtools/actionsDenylist'; -import { stateSanitizer } from './middleware/devtools/stateSanitizer'; +import { api } from 'services/api'; import { LOCALSTORAGE_PREFIX } from './constants'; import { serialize } from './enhancers/reduxRemember/serialize'; import { unserialize } from './enhancers/reduxRemember/unserialize'; -import { api } from 'services/api'; +import { actionSanitizer } from './middleware/devtools/actionSanitizer'; +import { actionsDenylist } from './middleware/devtools/actionsDenylist'; +import { stateSanitizer } from './middleware/devtools/stateSanitizer'; const allReducers = { canvas: canvasReducer, @@ -50,6 +51,7 @@ const allReducers = { dynamicPrompts: dynamicPromptsReducer, batch: batchReducer, imageDeletion: imageDeletionReducer, + lora: loraReducer, [api.reducerPath]: api.reducer, }; @@ -69,6 +71,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [ 'controlNet', 'dynamicPrompts', 'batch', + 'lora', // 'boards', // 'hotkeys', // 'config', diff --git a/invokeai/frontend/web/src/common/components/IAICollapse.tsx b/invokeai/frontend/web/src/common/components/IAICollapse.tsx index 5db26f3841..09dc1392e2 100644 --- a/invokeai/frontend/web/src/common/components/IAICollapse.tsx +++ b/invokeai/frontend/web/src/common/components/IAICollapse.tsx @@ -4,22 +4,25 @@ import { Collapse, Flex, Spacer, - Switch, + Text, useColorMode, + useDisclosure, } from '@chakra-ui/react'; +import { AnimatePresence, motion } from 'framer-motion'; import { PropsWithChildren, memo } from 'react'; import { mode } from 'theme/util/mode'; export type IAIToggleCollapseProps = PropsWithChildren & { label: string; - isOpen: boolean; - onToggle: () => void; - withSwitch?: boolean; + activeLabel?: string; + defaultIsOpen?: boolean; }; const IAICollapse = (props: IAIToggleCollapseProps) => { - const { label, isOpen, onToggle, children, withSwitch = false } = props; + const { label, activeLabel, children, defaultIsOpen = false } = props; + const { isOpen, onToggle } = useDisclosure({ defaultIsOpen }); const { colorMode } = useColorMode(); + return ( { alignItems: 'center', p: 2, px: 4, + gap: 2, borderTopRadius: 'base', borderBottomRadius: isOpen ? 0 : 'base', bg: isOpen @@ -48,19 +52,40 @@ const IAICollapse = (props: IAIToggleCollapseProps) => { }} > {label} + + {activeLabel && ( + + + {activeLabel} + + + )} + - {withSwitch && } - {!withSwitch && ( - - )} + { '&:focus-within': { borderColor: mode(accent200, accent600)(colorMode), }, - '&:disabled': { + '&[data-disabled]': { backgroundColor: mode(base300, base700)(colorMode), color: mode(base600, base400)(colorMode), }, diff --git a/invokeai/frontend/web/src/common/components/IAIMantineSelect.tsx b/invokeai/frontend/web/src/common/components/IAIMantineSelect.tsx index 9b023fd2d7..585dc106a8 100644 --- a/invokeai/frontend/web/src/common/components/IAIMantineSelect.tsx +++ b/invokeai/frontend/web/src/common/components/IAIMantineSelect.tsx @@ -64,7 +64,7 @@ const IAIMantineSelect = (props: IAISelectProps) => { '&:focus-within': { borderColor: mode(accent200, accent600)(colorMode), }, - '&:disabled': { + '&[data-disabled]': { backgroundColor: mode(base300, base700)(colorMode), color: mode(base600, base400)(colorMode), }, diff --git a/invokeai/frontend/web/src/common/components/IAISwitch.tsx b/invokeai/frontend/web/src/common/components/IAISwitch.tsx index 54a3b30a4f..d25ab0d87e 100644 --- a/invokeai/frontend/web/src/common/components/IAISwitch.tsx +++ b/invokeai/frontend/web/src/common/components/IAISwitch.tsx @@ -36,7 +36,6 @@ const IAISwitch = (props: Props) => { isDisabled={isDisabled} width={width} display="flex" - gap={4} alignItems="center" {...formControlProps} > @@ -47,6 +46,7 @@ const IAISwitch = (props: Props) => { sx={{ cursor: isDisabled ? 'not-allowed' : 'pointer', ...formLabelProps?.sx, + pe: 4, }} {...formLabelProps} > diff --git a/invokeai/frontend/web/src/features/batch/components/BatchImage.tsx b/invokeai/frontend/web/src/features/batch/components/BatchImage.tsx index 822b1cf183..4a6250f93a 100644 --- a/invokeai/frontend/web/src/features/batch/components/BatchImage.tsx +++ b/invokeai/frontend/web/src/features/batch/components/BatchImage.tsx @@ -1,28 +1,29 @@ import { Box, Icon, Skeleton } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd'; +import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { FaExclamationCircle } from 'react-icons/fa'; -import { useGetImageDTOQuery } from 'services/api/endpoints/images'; -import { MouseEvent, memo, useCallback, useMemo } from 'react'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIDndImage from 'common/components/IAIDndImage'; import { batchImageRangeEndSelected, batchImageSelected, batchImageSelectionToggled, imageRemovedFromBatch, } from 'features/batch/store/batchSlice'; -import IAIDndImage from 'common/components/IAIDndImage'; -import { createSelector } from '@reduxjs/toolkit'; -import { RootState, stateSelector } from 'app/store/store'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd'; +import { MouseEvent, memo, useCallback, useMemo } from 'react'; +import { FaExclamationCircle } from 'react-icons/fa'; +import { useGetImageDTOQuery } from 'services/api/endpoints/images'; -const isSelectedSelector = createSelector( - [stateSelector, (state: RootState, imageName: string) => imageName], - (state, imageName) => ({ - selection: state.batch.selection, - isSelected: state.batch.selection.includes(imageName), - }), - defaultSelectorOptions -); +const makeSelector = (image_name: string) => + createSelector( + [stateSelector], + (state) => ({ + selectionCount: state.batch.selection.length, + isSelected: state.batch.selection.includes(image_name), + }), + defaultSelectorOptions + ); type BatchImageProps = { imageName: string; @@ -37,10 +38,13 @@ const BatchImage = (props: BatchImageProps) => { } = useGetImageDTOQuery(props.imageName); const dispatch = useAppDispatch(); - const { isSelected, selection } = useAppSelector((state) => - isSelectedSelector(state, props.imageName) + const selector = useMemo( + () => makeSelector(props.imageName), + [props.imageName] ); + const { isSelected, selectionCount } = useAppSelector(selector); + const handleClickRemove = useCallback(() => { dispatch(imageRemovedFromBatch(props.imageName)); }, [dispatch, props.imageName]); @@ -59,13 +63,10 @@ const BatchImage = (props: BatchImageProps) => { ); const draggableData = useMemo(() => { - if (selection.length > 1) { + if (selectionCount > 1) { return { id: 'batch', - payloadType: 'IMAGE_NAMES', - payload: { - imageNames: selection, - }, + payloadType: 'BATCH_SELECTION', }; } @@ -76,7 +77,7 @@ const BatchImage = (props: BatchImageProps) => { payload: { imageDTO }, }; } - }, [imageDTO, selection]); + }, [imageDTO, selectionCount]); if (isError) { return ; diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx index df73f1141d..dde449a464 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx @@ -1,25 +1,22 @@ -import { memo, useCallback, useMemo, useState } from 'react'; -import { ImageDTO } from 'services/api/types'; -import { - ControlNetConfig, - controlNetImageChanged, - controlNetSelector, -} from '../store/controlNetSlice'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { Box, Flex, SystemStyleObject } from '@chakra-ui/react'; -import IAIDndImage from 'common/components/IAIDndImage'; import { createSelector } from '@reduxjs/toolkit'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { IAILoadingImageFallback } from 'common/components/IAIImageFallback'; -import IAIIconButton from 'common/components/IAIIconButton'; -import { FaUndo } from 'react-icons/fa'; -import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { skipToken } from '@reduxjs/toolkit/dist/query'; import { TypesafeDraggableData, TypesafeDroppableData, } from 'app/components/ImageDnd/typesafeDnd'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIDndImage from 'common/components/IAIDndImage'; +import { IAILoadingImageFallback } from 'common/components/IAIImageFallback'; +import { memo, useCallback, useMemo, useState } from 'react'; +import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { PostUploadAction } from 'services/api/thunks/image'; +import { + ControlNetConfig, + controlNetImageChanged, + controlNetSelector, +} from '../store/controlNetSlice'; const selector = createSelector( controlNetSelector, @@ -83,15 +80,14 @@ const ControlNetImagePreview = (props: Props) => { } }, [controlImage, controlNetId]); - const droppableData = useMemo(() => { - if (controlNetId) { - return { - id: controlNetId, - actionType: 'SET_CONTROLNET_IMAGE', - context: { controlNetId }, - }; - } - }, [controlNetId]); + const droppableData = useMemo( + () => ({ + id: controlNetId, + actionType: 'SET_CONTROLNET_IMAGE', + context: { controlNetId }, + }), + [controlNetId] + ); const postUploadAction = useMemo( () => ({ type: 'SET_CONTROLNET_IMAGE', controlNetId }), diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetFeatureToggle.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetFeatureToggle.tsx new file mode 100644 index 0000000000..3a7eea2fbf --- /dev/null +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetFeatureToggle.tsx @@ -0,0 +1,36 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAISwitch from 'common/components/IAISwitch'; +import { isControlNetEnabledToggled } from 'features/controlNet/store/controlNetSlice'; +import { useCallback } from 'react'; + +const selector = createSelector( + stateSelector, + (state) => { + const { isEnabled } = state.controlNet; + + return { isEnabled }; + }, + defaultSelectorOptions +); + +const ParamControlNetFeatureToggle = () => { + const { isEnabled } = useAppSelector(selector); + const dispatch = useAppDispatch(); + + const handleChange = useCallback(() => { + dispatch(isControlNetEnabledToggled()); + }, [dispatch]); + + return ( + + ); +}; + +export default ParamControlNetFeatureToggle; diff --git a/invokeai/frontend/web/src/features/controlNet/util/getValidControlNets.ts b/invokeai/frontend/web/src/features/controlNet/util/getValidControlNets.ts new file mode 100644 index 0000000000..4bff39db63 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlNet/util/getValidControlNets.ts @@ -0,0 +1,15 @@ +import { filter } from 'lodash-es'; +import { ControlNetConfig } from '../store/controlNetSlice'; + +export const getValidControlNets = ( + controlNets: Record +) => { + const validControlNets = filter( + controlNets, + (c) => + c.isEnabled && + (Boolean(c.processedControlImage) || + (c.processorType === 'none' && Boolean(c.controlImage))) + ); + return validControlNets; +}; diff --git a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCollapse.tsx b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCollapse.tsx index 1aefecf3e6..0e41fad994 100644 --- a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCollapse.tsx +++ b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCollapse.tsx @@ -1,40 +1,30 @@ +import { Flex } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAICollapse from 'common/components/IAICollapse'; -import { useCallback } from 'react'; -import { isEnabledToggled } from '../store/slice'; -import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts'; import ParamDynamicPromptsCombinatorial from './ParamDynamicPromptsCombinatorial'; -import { Flex } from '@chakra-ui/react'; +import ParamDynamicPromptsToggle from './ParamDynamicPromptsEnabled'; +import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts'; const selector = createSelector( stateSelector, (state) => { const { isEnabled } = state.dynamicPrompts; - return { isEnabled }; + return { activeLabel: isEnabled ? 'Enabled' : undefined }; }, defaultSelectorOptions ); const ParamDynamicPromptsCollapse = () => { - const dispatch = useAppDispatch(); - const { isEnabled } = useAppSelector(selector); - - const handleToggleIsEnabled = useCallback(() => { - dispatch(isEnabledToggled()); - }, [dispatch]); + const { activeLabel } = useAppSelector(selector); return ( - + + diff --git a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCombinatorial.tsx b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCombinatorial.tsx index 30c2240c37..cb930acd3b 100644 --- a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCombinatorial.tsx +++ b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCombinatorial.tsx @@ -1,23 +1,23 @@ -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { combinatorialToggled } from '../store/slice'; import { createSelector } from '@reduxjs/toolkit'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { useCallback } from 'react'; import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAISwitch from 'common/components/IAISwitch'; +import { useCallback } from 'react'; +import { combinatorialToggled } from '../store/slice'; const selector = createSelector( stateSelector, (state) => { - const { combinatorial } = state.dynamicPrompts; + const { combinatorial, isEnabled } = state.dynamicPrompts; - return { combinatorial }; + return { combinatorial, isDisabled: !isEnabled }; }, defaultSelectorOptions ); const ParamDynamicPromptsCombinatorial = () => { - const { combinatorial } = useAppSelector(selector); + const { combinatorial, isDisabled } = useAppSelector(selector); const dispatch = useAppDispatch(); const handleChange = useCallback(() => { @@ -26,6 +26,7 @@ const ParamDynamicPromptsCombinatorial = () => { return ( { + const { isEnabled } = state.dynamicPrompts; + + return { isEnabled }; + }, + defaultSelectorOptions +); + +const ParamDynamicPromptsToggle = () => { + const dispatch = useAppDispatch(); + const { isEnabled } = useAppSelector(selector); + + const handleToggleIsEnabled = useCallback(() => { + dispatch(isEnabledToggled()); + }, [dispatch]); + + return ( + + ); +}; + +export default ParamDynamicPromptsToggle; diff --git a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx index 19f02ae3e5..172120fd1e 100644 --- a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx +++ b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx @@ -1,25 +1,31 @@ -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import IAISlider from 'common/components/IAISlider'; -import { maxPromptsChanged, maxPromptsReset } from '../store/slice'; import { createSelector } from '@reduxjs/toolkit'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { useCallback } from 'react'; import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAISlider from 'common/components/IAISlider'; +import { useCallback } from 'react'; +import { maxPromptsChanged, maxPromptsReset } from '../store/slice'; const selector = createSelector( stateSelector, (state) => { - const { maxPrompts, combinatorial } = state.dynamicPrompts; + const { maxPrompts, combinatorial, isEnabled } = state.dynamicPrompts; const { min, sliderMax, inputMax } = state.config.sd.dynamicPrompts.maxPrompts; - return { maxPrompts, min, sliderMax, inputMax, combinatorial }; + return { + maxPrompts, + min, + sliderMax, + inputMax, + isDisabled: !isEnabled || !combinatorial, + }; }, defaultSelectorOptions ); const ParamDynamicPromptsMaxPrompts = () => { - const { maxPrompts, min, sliderMax, inputMax, combinatorial } = + const { maxPrompts, min, sliderMax, inputMax, isDisabled } = useAppSelector(selector); const dispatch = useAppDispatch(); @@ -37,7 +43,7 @@ const ParamDynamicPromptsMaxPrompts = () => { return ( image_name], - ({ gallery }, image_name) => { - const isSelected = gallery.selection.includes(image_name); - const selection = gallery.selection; - return { - isSelected, - selection, - }; - }, - defaultSelectorOptions -); +export const makeSelector = (image_name: string) => + createSelector( + [stateSelector], + ({ gallery }) => { + const isSelected = gallery.selection.includes(image_name); + const selectionCount = gallery.selection.length; + return { + isSelected, + selectionCount, + }; + }, + defaultSelectorOptions + ); interface HoverableImageProps { imageDTO: ImageDTO; @@ -38,13 +39,13 @@ interface HoverableImageProps { * Gallery image component with delete/use all/use seed buttons on hover. */ const GalleryImage = (props: HoverableImageProps) => { - const { isSelected, selection } = useAppSelector((state) => - selector(state, props.imageDTO) - ); - const { imageDTO } = props; const { image_url, thumbnail_url, image_name } = imageDTO; + const localSelector = useMemo(() => makeSelector(image_name), [image_name]); + + const { isSelected, selectionCount } = useAppSelector(localSelector); + const dispatch = useAppDispatch(); const { t } = useTranslation(); @@ -74,11 +75,10 @@ const GalleryImage = (props: HoverableImageProps) => { ); const draggableData = useMemo(() => { - if (selection.length > 1) { + if (selectionCount > 1) { return { id: 'gallery-image', - payloadType: 'IMAGE_NAMES', - payload: { imageNames: selection }, + payloadType: 'GALLERY_SELECTION', }; } @@ -89,7 +89,7 @@ const GalleryImage = (props: HoverableImageProps) => { payload: { imageDTO }, }; } - }, [imageDTO, selection]); + }, [imageDTO, selectionCount]); return ( diff --git a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts index f4d2babf38..41a52e3452 100644 --- a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts @@ -7,7 +7,6 @@ import { import { RootState } from 'app/store/store'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { dateComparator } from 'common/util/dateComparator'; -import { imageDeletionConfirmed } from 'features/imageDeletion/store/imageDeletionSlice'; import { keyBy, uniq } from 'lodash-es'; import { boardsApi } from 'services/api/endpoints/boards'; import { @@ -174,11 +173,6 @@ export const gallerySlice = createSlice({ state.limit = limit; state.total = total; }); - builder.addCase(imageDeletionConfirmed, (state, action) => { - // Image deleted - const { image_name } = action.payload.imageDTO; - imagesAdapter.removeOne(state, image_name); - }); builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { const { image_name, image_url, thumbnail_url } = action.payload; diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx new file mode 100644 index 0000000000..23459e9410 --- /dev/null +++ b/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx @@ -0,0 +1,59 @@ +import { Flex } from '@chakra-ui/react'; +import { useAppDispatch } from 'app/store/storeHooks'; +import IAIIconButton from 'common/components/IAIIconButton'; +import IAISlider from 'common/components/IAISlider'; +import { memo, useCallback } from 'react'; +import { FaTrash } from 'react-icons/fa'; +import { Lora, loraRemoved, loraWeightChanged } from '../store/loraSlice'; + +type Props = { + lora: Lora; +}; + +const ParamLora = (props: Props) => { + const dispatch = useAppDispatch(); + const { lora } = props; + + const handleChange = useCallback( + (v: number) => { + dispatch(loraWeightChanged({ id: lora.id, weight: v })); + }, + [dispatch, lora.id] + ); + + const handleReset = useCallback(() => { + dispatch(loraWeightChanged({ id: lora.id, weight: 1 })); + }, [dispatch, lora.id]); + + const handleRemoveLora = useCallback(() => { + dispatch(loraRemoved(lora.id)); + }, [dispatch, lora.id]); + + return ( + + + } + colorScheme="error" + /> + + ); +}; + +export default memo(ParamLora); diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx new file mode 100644 index 0000000000..6e69f036df --- /dev/null +++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx @@ -0,0 +1,36 @@ +import { Flex } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAICollapse from 'common/components/IAICollapse'; +import { size } from 'lodash-es'; +import { memo } from 'react'; +import ParamLoraList from './ParamLoraList'; +import ParamLoraSelect from './ParamLoraSelect'; + +const selector = createSelector( + stateSelector, + (state) => { + const loraCount = size(state.lora.loras); + return { + activeLabel: loraCount > 0 ? `${loraCount} Active` : undefined, + }; + }, + defaultSelectorOptions +); + +const ParamLoraCollapse = () => { + const { activeLabel } = useAppSelector(selector); + + return ( + + + + + + + ); +}; + +export default memo(ParamLoraCollapse); diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx new file mode 100644 index 0000000000..89432ac862 --- /dev/null +++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx @@ -0,0 +1,24 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import { map } from 'lodash-es'; +import ParamLora from './ParamLora'; + +const selector = createSelector( + stateSelector, + ({ lora }) => { + const { loras } = lora; + + return { loras }; + }, + defaultSelectorOptions +); + +const ParamLoraList = () => { + const { loras } = useAppSelector(selector); + + return map(loras, (lora) => ); +}; + +export default ParamLoraList; diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx new file mode 100644 index 0000000000..54ac3d615d --- /dev/null +++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx @@ -0,0 +1,107 @@ +import { Text } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; +import { forEach } from 'lodash-es'; +import { forwardRef, useCallback, useMemo } from 'react'; +import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; +import { loraAdded } from '../store/loraSlice'; + +type LoraSelectItem = { + label: string; + value: string; + description?: string; +}; + +const selector = createSelector( + stateSelector, + ({ lora }) => ({ + loras: lora.loras, + }), + defaultSelectorOptions +); + +const ParamLoraSelect = () => { + const dispatch = useAppDispatch(); + const { loras } = useAppSelector(selector); + const { data: lorasQueryData } = useGetLoRAModelsQuery(); + + const data = useMemo(() => { + if (!lorasQueryData) { + return []; + } + + const data: LoraSelectItem[] = []; + + forEach(lorasQueryData.entities, (lora, id) => { + if (!lora || Boolean(id in loras)) { + return; + } + + data.push({ + value: id, + label: lora.name, + description: lora.description, + }); + }); + + return data; + }, [loras, lorasQueryData]); + + const handleChange = useCallback( + (v: string[]) => { + const loraEntity = lorasQueryData?.entities[v[0]]; + if (!loraEntity) { + return; + } + v[0] && dispatch(loraAdded(loraEntity)); + }, + [dispatch, lorasQueryData?.entities] + ); + + return ( + + item.label.toLowerCase().includes(value.toLowerCase().trim()) || + item.value.toLowerCase().includes(value.toLowerCase().trim()) + } + onChange={handleChange} + /> + ); +}; + +interface ItemProps extends React.ComponentPropsWithoutRef<'div'> { + value: string; + label: string; + description?: string; +} + +const SelectItem = forwardRef( + ({ label, description, ...others }: ItemProps, ref) => { + return ( +
+
+ {label} + {description && ( + + {description} + + )} +
+
+ ); + } +); + +SelectItem.displayName = 'SelectItem'; + +export default ParamLoraSelect; diff --git a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts new file mode 100644 index 0000000000..c9b290eb2d --- /dev/null +++ b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts @@ -0,0 +1,46 @@ +import { PayloadAction, createSlice } from '@reduxjs/toolkit'; +import { LoRAModelConfigEntity } from 'services/api/endpoints/models'; + +export type Lora = { + id: string; + name: string; + weight: number; +}; + +export const defaultLoRAConfig: Omit = { + weight: 1, +}; + +export type LoraState = { + loras: Record; +}; + +export const intialLoraState: LoraState = { + loras: {}, +}; + +export const loraSlice = createSlice({ + name: 'lora', + initialState: intialLoraState, + reducers: { + loraAdded: (state, action: PayloadAction) => { + const { name, id } = action.payload; + state.loras[id] = { id, name, ...defaultLoRAConfig }; + }, + loraRemoved: (state, action: PayloadAction) => { + const id = action.payload; + delete state.loras[id]; + }, + loraWeightChanged: ( + state, + action: PayloadAction<{ id: string; weight: number }> + ) => { + const { id, weight } = action.payload; + state.loras[id].weight = weight; + }, + }, +}); + +export const { loraAdded, loraRemoved, loraWeightChanged } = loraSlice.actions; + +export default loraSlice.reducer; diff --git a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx index b3b91ccf5e..9925a48381 100644 --- a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx @@ -3,20 +3,22 @@ import { memo } from 'react'; import { InputFieldTemplate, InputFieldValue } from '../types/types'; import ArrayInputFieldComponent from './fields/ArrayInputFieldComponent'; import BooleanInputFieldComponent from './fields/BooleanInputFieldComponent'; -import EnumInputFieldComponent from './fields/EnumInputFieldComponent'; -import ImageInputFieldComponent from './fields/ImageInputFieldComponent'; -import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent'; -import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent'; -import UNetInputFieldComponent from './fields/UNetInputFieldComponent'; import ClipInputFieldComponent from './fields/ClipInputFieldComponent'; -import VaeInputFieldComponent from './fields/VaeInputFieldComponent'; +import ColorInputFieldComponent from './fields/ColorInputFieldComponent'; +import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent'; import ControlInputFieldComponent from './fields/ControlInputFieldComponent'; +import EnumInputFieldComponent from './fields/EnumInputFieldComponent'; +import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent'; +import ImageInputFieldComponent from './fields/ImageInputFieldComponent'; +import ItemInputFieldComponent from './fields/ItemInputFieldComponent'; +import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent'; +import LoRAModelInputFieldComponent from './fields/LoRAModelInputFieldComponent'; import ModelInputFieldComponent from './fields/ModelInputFieldComponent'; import NumberInputFieldComponent from './fields/NumberInputFieldComponent'; import StringInputFieldComponent from './fields/StringInputFieldComponent'; -import ColorInputFieldComponent from './fields/ColorInputFieldComponent'; -import ItemInputFieldComponent from './fields/ItemInputFieldComponent'; -import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent'; +import UNetInputFieldComponent from './fields/UNetInputFieldComponent'; +import VaeInputFieldComponent from './fields/VaeInputFieldComponent'; +import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent'; type InputFieldComponentProps = { nodeId: string; @@ -152,6 +154,26 @@ const InputFieldComponent = (props: InputFieldComponentProps) => { ); } + if (type === 'vae_model' && template.type === 'vae_model') { + return ( + + ); + } + + if (type === 'lora_model' && template.type === 'lora_model') { + return ( + + ); + } + if (type === 'array' && template.type === 'array') { return ( @@ -34,23 +32,6 @@ const ImageInputFieldComponent = ( isSuccess, } = useGetImageDTOQuery(field.value?.image_name ?? skipToken); - const handleDrop = useCallback( - ({ image_name }: ImageDTO) => { - if (field.value?.image_name === image_name) { - return; - } - - dispatch( - fieldValueChanged({ - nodeId, - fieldName: field.name, - value: { image_name }, - }) - ); - }, - [dispatch, field.name, field.value, nodeId] - ); - const handleReset = useCallback(() => { dispatch( fieldValueChanged({ @@ -71,15 +52,14 @@ const ImageInputFieldComponent = ( } }, [field.name, imageDTO, nodeId]); - const droppableData = useMemo(() => { - if (imageDTO) { - return { - id: `node-${nodeId}-${field.name}`, - actionType: 'SET_NODES_IMAGE', - context: { nodeId, fieldName: field.name }, - }; - } - }, [field.name, imageDTO, nodeId]); + const droppableData = useMemo( + () => ({ + id: `node-${nodeId}-${field.name}`, + actionType: 'SET_NODES_IMAGE', + context: { nodeId, fieldName: field.name }, + }), + [field.name, nodeId] + ); const postUploadAction = useMemo( () => ({ diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx new file mode 100644 index 0000000000..02cdfd454d --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx @@ -0,0 +1,102 @@ +import { SelectItem } from '@mantine/core'; +import { useAppDispatch } from 'app/store/storeHooks'; +import IAIMantineSelect from 'common/components/IAIMantineSelect'; +import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; +import { + VaeModelInputFieldTemplate, + VaeModelInputFieldValue, +} from 'features/nodes/types/types'; +import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect'; +import { forEach, isString } from 'lodash-es'; +import { memo, useCallback, useEffect, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; +import { FieldComponentProps } from './types'; + +const LoRAModelInputFieldComponent = ( + props: FieldComponentProps< + VaeModelInputFieldValue, + VaeModelInputFieldTemplate + > +) => { + const { nodeId, field } = props; + + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + + const { data: loraModels } = useGetLoRAModelsQuery(); + + const selectedModel = useMemo( + () => loraModels?.entities[field.value ?? loraModels.ids[0]], + [loraModels?.entities, loraModels?.ids, field.value] + ); + + const data = useMemo(() => { + if (!loraModels) { + return []; + } + + const data: SelectItem[] = []; + + forEach(loraModels.entities, (model, id) => { + if (!model) { + return; + } + + data.push({ + value: id, + label: model.name, + group: BASE_MODEL_NAME_MAP[model.base_model], + }); + }); + + return data; + }, [loraModels]); + + const handleValueChanged = useCallback( + (v: string | null) => { + if (!v) { + return; + } + + dispatch( + fieldValueChanged({ + nodeId, + fieldName: field.name, + value: v, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + + useEffect(() => { + if (field.value && loraModels?.ids.includes(field.value)) { + return; + } + + const firstLora = loraModels?.ids[0]; + + if (!isString(firstLora)) { + return; + } + + handleValueChanged(firstLora); + }, [field.value, handleValueChanged, loraModels?.ids]); + + return ( + + ); +}; + +export default memo(LoRAModelInputFieldComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx index 741662655f..ee739e1002 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx @@ -6,13 +6,13 @@ import { ModelInputFieldValue, } from 'features/nodes/types/types'; -import { memo, useCallback, useEffect, useMemo } from 'react'; -import { FieldComponentProps } from './types'; -import { forEach, isString } from 'lodash-es'; -import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; +import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect'; +import { forEach, isString } from 'lodash-es'; +import { memo, useCallback, useEffect, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { useListModelsQuery } from 'services/api/endpoints/models'; +import { useGetMainModelsQuery } from 'services/api/endpoints/models'; +import { FieldComponentProps } from './types'; const ModelInputFieldComponent = ( props: FieldComponentProps @@ -22,18 +22,16 @@ const ModelInputFieldComponent = ( const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { data: pipelineModels } = useListModelsQuery({ - model_type: 'main', - }); + const { data: mainModels } = useGetMainModelsQuery(); const data = useMemo(() => { - if (!pipelineModels) { + if (!mainModels) { return []; } const data: SelectItem[] = []; - forEach(pipelineModels.entities, (model, id) => { + forEach(mainModels.entities, (model, id) => { if (!model) { return; } @@ -46,11 +44,11 @@ const ModelInputFieldComponent = ( }); return data; - }, [pipelineModels]); + }, [mainModels]); const selectedModel = useMemo( - () => pipelineModels?.entities[field.value ?? pipelineModels.ids[0]], - [pipelineModels?.entities, pipelineModels?.ids, field.value] + () => mainModels?.entities[field.value ?? mainModels.ids[0]], + [mainModels?.entities, mainModels?.ids, field.value] ); const handleValueChanged = useCallback( @@ -71,18 +69,18 @@ const ModelInputFieldComponent = ( ); useEffect(() => { - if (field.value && pipelineModels?.ids.includes(field.value)) { + if (field.value && mainModels?.ids.includes(field.value)) { return; } - const firstModel = pipelineModels?.ids[0]; + const firstModel = mainModels?.ids[0]; if (!isString(firstModel)) { return; } handleValueChanged(firstModel); - }, [field.value, handleValueChanged, pipelineModels?.ids]); + }, [field.value, handleValueChanged, mainModels?.ids]); return ( +) => { + const { nodeId, field } = props; + + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + + const { data: vaeModels } = useGetVaeModelsQuery(); + + const selectedModel = useMemo( + () => vaeModels?.entities[field.value ?? vaeModels.ids[0]], + [vaeModels?.entities, vaeModels?.ids, field.value] + ); + + const data = useMemo(() => { + if (!vaeModels) { + return []; + } + + const data: SelectItem[] = []; + + forEach(vaeModels.entities, (model, id) => { + if (!model) { + return; + } + + data.push({ + value: id, + label: model.name, + group: BASE_MODEL_NAME_MAP[model.base_model], + }); + }); + + return data; + }, [vaeModels]); + + const handleValueChanged = useCallback( + (v: string | null) => { + if (!v) { + return; + } + + dispatch( + fieldValueChanged({ + nodeId, + fieldName: field.name, + value: v, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + + useEffect(() => { + if (field.value && vaeModels?.ids.includes(field.value)) { + return; + } + handleValueChanged('auto'); + }, [field.value, handleValueChanged, vaeModels?.ids]); + + return ( + + ); +}; + +export default memo(VaeModelInputFieldComponent); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index ffc93db2ba..4fa69c626b 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -1,5 +1,8 @@ import { createSlice, PayloadAction } from '@reduxjs/toolkit'; +import { RootState } from 'app/store/store'; +import { cloneDeep, uniqBy } from 'lodash-es'; import { OpenAPIV3 } from 'openapi-types'; +import { RgbaColor } from 'react-colorful'; import { addEdge, applyEdgeChanges, @@ -11,12 +14,9 @@ import { NodeChange, OnConnectStartParams, } from 'reactflow'; -import { ImageField } from 'services/api/types'; import { receivedOpenAPISchema } from 'services/api/thunks/schema'; +import { ImageField } from 'services/api/types'; import { InvocationTemplate, InvocationValue } from '../types/types'; -import { RgbaColor } from 'react-colorful'; -import { RootState } from 'app/store/store'; -import { cloneDeep, isArray, uniq, uniqBy } from 'lodash-es'; export type NodesState = { nodes: Node[]; diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index 9f6124c9d4..5fe780a286 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -17,6 +17,8 @@ export const FIELD_TYPE_MAP: Record = { ClipField: 'clip', VaeField: 'vae', model: 'model', + vae_model: 'vae_model', + lora_model: 'lora_model', array: 'array', item: 'item', ColorField: 'color', @@ -116,6 +118,18 @@ export const FIELDS: Record = { title: 'Model', description: 'Models are models.', }, + vae_model: { + color: 'teal', + colorCssVar: getColorTokenCssVariable('teal'), + title: 'VAE', + description: 'Models are models.', + }, + lora_model: { + color: 'teal', + colorCssVar: getColorTokenCssVariable('teal'), + title: 'LoRA', + description: 'Models are models.', + }, array: { color: 'gray', colorCssVar: getColorTokenCssVariable('gray'), diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index 9498bbd5d5..3de8cae9ff 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -64,6 +64,8 @@ export type FieldType = | 'vae' | 'control' | 'model' + | 'vae_model' + | 'lora_model' | 'array' | 'item' | 'color' @@ -91,6 +93,8 @@ export type InputFieldValue = | ControlInputFieldValue | EnumInputFieldValue | ModelInputFieldValue + | VaeModelInputFieldValue + | LoRAModelInputFieldValue | ArrayInputFieldValue | ItemInputFieldValue | ColorInputFieldValue @@ -116,6 +120,8 @@ export type InputFieldTemplate = | ControlInputFieldTemplate | EnumInputFieldTemplate | ModelInputFieldTemplate + | VaeModelInputFieldTemplate + | LoRAModelInputFieldTemplate | ArrayInputFieldTemplate | ItemInputFieldTemplate | ColorInputFieldTemplate @@ -228,6 +234,16 @@ export type ModelInputFieldValue = FieldValueBase & { value?: string; }; +export type VaeModelInputFieldValue = FieldValueBase & { + type: 'vae_model'; + value?: string; +}; + +export type LoRAModelInputFieldValue = FieldValueBase & { + type: 'lora_model'; + value?: string; +}; + export type ArrayInputFieldValue = FieldValueBase & { type: 'array'; value?: (string | number)[]; @@ -305,6 +321,21 @@ export type ConditioningInputFieldTemplate = InputFieldTemplateBase & { type: 'conditioning'; }; +export type UNetInputFieldTemplate = InputFieldTemplateBase & { + default: undefined; + type: 'unet'; +}; + +export type ClipInputFieldTemplate = InputFieldTemplateBase & { + default: undefined; + type: 'clip'; +}; + +export type VaeInputFieldTemplate = InputFieldTemplateBase & { + default: undefined; + type: 'vae'; +}; + export type ControlInputFieldTemplate = InputFieldTemplateBase & { default: undefined; type: 'control'; @@ -322,6 +353,16 @@ export type ModelInputFieldTemplate = InputFieldTemplateBase & { type: 'model'; }; +export type VaeModelInputFieldTemplate = InputFieldTemplateBase & { + default: string; + type: 'vae_model'; +}; + +export type LoRAModelInputFieldTemplate = InputFieldTemplateBase & { + default: string; + type: 'lora_model'; +}; + export type ArrayInputFieldTemplate = InputFieldTemplateBase & { default: []; type: 'array'; diff --git a/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts index 11ceb23763..5c4d67ebd3 100644 --- a/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts @@ -1,5 +1,5 @@ import { RootState } from 'app/store/store'; -import { filter } from 'lodash-es'; +import { getValidControlNets } from 'features/controlNet/util/getValidControlNets'; import { CollectInvocation, ControlNetInvocation } from 'services/api/types'; import { NonNullableGraph } from '../types/types'; import { CONTROL_NET_COLLECT } from './graphBuilders/constants'; @@ -11,13 +11,7 @@ export const addControlNetToLinearGraph = ( ): void => { const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet; - const validControlNets = filter( - controlNets, - (c) => - c.isEnabled && - (Boolean(c.processedControlImage) || - (c.processorType === 'none' && Boolean(c.controlImage))) - ); + const validControlNets = getValidControlNets(controlNets); if (isControlNetEnabled && Boolean(validControlNets.length)) { if (validControlNets.length > 1) { diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts index 6f971dd60b..1c2dbc0c3e 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts @@ -3,27 +3,29 @@ import { OpenAPIV3 } from 'openapi-types'; import { FIELD_TYPE_MAP } from '../types/constants'; import { isSchemaObject } from '../types/typeGuards'; import { - BooleanInputFieldTemplate, - EnumInputFieldTemplate, - FloatInputFieldTemplate, - ImageInputFieldTemplate, - IntegerInputFieldTemplate, - LatentsInputFieldTemplate, - ConditioningInputFieldTemplate, - UNetInputFieldTemplate, - ClipInputFieldTemplate, - VaeInputFieldTemplate, - ControlInputFieldTemplate, - StringInputFieldTemplate, - ModelInputFieldTemplate, ArrayInputFieldTemplate, - ItemInputFieldTemplate, + BooleanInputFieldTemplate, + ClipInputFieldTemplate, ColorInputFieldTemplate, - InputFieldTemplateBase, - OutputFieldTemplate, - TypeHints, + ConditioningInputFieldTemplate, + ControlInputFieldTemplate, + EnumInputFieldTemplate, FieldType, + FloatInputFieldTemplate, ImageCollectionInputFieldTemplate, + ImageInputFieldTemplate, + InputFieldTemplateBase, + IntegerInputFieldTemplate, + ItemInputFieldTemplate, + LatentsInputFieldTemplate, + LoRAModelInputFieldTemplate, + ModelInputFieldTemplate, + OutputFieldTemplate, + StringInputFieldTemplate, + TypeHints, + UNetInputFieldTemplate, + VaeInputFieldTemplate, + VaeModelInputFieldTemplate, } from '../types/types'; export type BaseFieldProperties = 'name' | 'title' | 'description'; @@ -175,6 +177,36 @@ const buildModelInputFieldTemplate = ({ return template; }; +const buildVaeModelInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): VaeModelInputFieldTemplate => { + const template: VaeModelInputFieldTemplate = { + ...baseField, + type: 'vae_model', + inputRequirement: 'always', + inputKind: 'direct', + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildLoRAModelInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): LoRAModelInputFieldTemplate => { + const template: LoRAModelInputFieldTemplate = { + ...baseField, + type: 'lora_model', + inputRequirement: 'always', + inputKind: 'direct', + default: schemaObject.default ?? undefined, + }; + + return template; +}; + const buildImageInputFieldTemplate = ({ schemaObject, baseField, @@ -441,6 +473,12 @@ export const buildInputFieldTemplate = ( if (['model'].includes(fieldType)) { return buildModelInputFieldTemplate({ schemaObject, baseField }); } + if (['vae_model'].includes(fieldType)) { + return buildVaeModelInputFieldTemplate({ schemaObject, baseField }); + } + if (['lora_model'].includes(fieldType)) { + return buildLoRAModelInputFieldTemplate({ schemaObject, baseField }); + } if (['enum'].includes(fieldType)) { return buildEnumInputFieldTemplate({ schemaObject, baseField }); } diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts index e05ef404c0..950038b691 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts @@ -75,6 +75,14 @@ export const buildInputFieldValue = ( if (template.type === 'model') { fieldValue.value = undefined; } + + if (template.type === 'vae_model') { + fieldValue.value = undefined; + } + + if (template.type === 'lora_model') { + fieldValue.value = undefined; + } } return fieldValue; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts new file mode 100644 index 0000000000..9712ef4d5f --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts @@ -0,0 +1,148 @@ +import { RootState } from 'app/store/store'; +import { NonNullableGraph } from 'features/nodes/types/types'; +import { forEach, size } from 'lodash-es'; +import { LoraLoaderInvocation } from 'services/api/types'; +import { modelIdToLoRAModelField } from '../modelIdToLoRAName'; +import { + LORA_LOADER, + MAIN_MODEL_LOADER, + NEGATIVE_CONDITIONING, + POSITIVE_CONDITIONING, +} from './constants'; + +export const addLoRAsToGraph = ( + graph: NonNullableGraph, + state: RootState, + baseNodeId: string +): void => { + /** + * LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them. + * They then output the UNet and CLIP models references on to either the next LoRA in the chain, + * or to the inference/conditioning nodes. + * + * So we need to inject a LoRA chain into the graph. + */ + + const { loras } = state.lora; + const loraCount = size(loras); + + if (loraCount > 0) { + // remove any existing connections from main model loader, we need to insert the lora nodes + graph.edges = graph.edges.filter( + (e) => + !( + e.source.node_id === MAIN_MODEL_LOADER && + ['unet', 'clip'].includes(e.source.field) + ) + ); + } + + // we need to remember the last lora so we can chain from it + let lastLoraNodeId = ''; + let currentLoraIndex = 0; + + forEach(loras, (lora) => { + const { id, name, weight } = lora; + const loraField = modelIdToLoRAModelField(id); + const currentLoraNodeId = `${LORA_LOADER}_${loraField.model_name.replace( + '.', + '_' + )}`; + + const loraLoaderNode: LoraLoaderInvocation = { + type: 'lora_loader', + id: currentLoraNodeId, + lora: loraField, + weight, + }; + + graph.nodes[currentLoraNodeId] = loraLoaderNode; + + if (currentLoraIndex === 0) { + // first lora = start the lora chain, attach directly to model loader + graph.edges.push({ + source: { + node_id: MAIN_MODEL_LOADER, + field: 'unet', + }, + destination: { + node_id: currentLoraNodeId, + field: 'unet', + }, + }); + + graph.edges.push({ + source: { + node_id: MAIN_MODEL_LOADER, + field: 'clip', + }, + destination: { + node_id: currentLoraNodeId, + field: 'clip', + }, + }); + } else { + // we are in the middle of the lora chain, instead connect to the previous lora + graph.edges.push({ + source: { + node_id: lastLoraNodeId, + field: 'unet', + }, + destination: { + node_id: currentLoraNodeId, + field: 'unet', + }, + }); + graph.edges.push({ + source: { + node_id: lastLoraNodeId, + field: 'clip', + }, + destination: { + node_id: currentLoraNodeId, + field: 'clip', + }, + }); + } + + if (currentLoraIndex === loraCount - 1) { + // final lora, end the lora chain - we need to connect up to inference and conditioning nodes + graph.edges.push({ + source: { + node_id: currentLoraNodeId, + field: 'unet', + }, + destination: { + node_id: baseNodeId, + field: 'unet', + }, + }); + + graph.edges.push({ + source: { + node_id: currentLoraNodeId, + field: 'clip', + }, + destination: { + node_id: POSITIVE_CONDITIONING, + field: 'clip', + }, + }); + + graph.edges.push({ + source: { + node_id: currentLoraNodeId, + field: 'clip', + }, + destination: { + node_id: NEGATIVE_CONDITIONING, + field: 'clip', + }, + }); + } + + // increment the lora for the next one in the chain + lastLoraNodeId = currentLoraNodeId; + currentLoraIndex += 1; + }); +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts new file mode 100644 index 0000000000..4dd3d644ee --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts @@ -0,0 +1,68 @@ +import { RootState } from 'app/store/store'; +import { NonNullableGraph } from 'features/nodes/types/types'; +import { modelIdToVAEModelField } from '../modelIdToVAEModelField'; +import { + IMAGE_TO_IMAGE_GRAPH, + IMAGE_TO_LATENTS, + INPAINT, + INPAINT_GRAPH, + LATENTS_TO_IMAGE, + MAIN_MODEL_LOADER, + TEXT_TO_IMAGE_GRAPH, + VAE_LOADER, +} from './constants'; + +export const addVAEToGraph = ( + graph: NonNullableGraph, + state: RootState +): void => { + const { vae: vaeId } = state.generation; + const vae_model = modelIdToVAEModelField(vaeId); + + if (vaeId !== 'auto') { + graph.nodes[VAE_LOADER] = { + type: 'vae_loader', + id: VAE_LOADER, + vae_model, + }; + } + + if (graph.id === TEXT_TO_IMAGE_GRAPH || graph.id === IMAGE_TO_IMAGE_GRAPH) { + graph.edges.push({ + source: { + node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER, + field: 'vae', + }, + destination: { + node_id: LATENTS_TO_IMAGE, + field: 'vae', + }, + }); + } + + if (graph.id === IMAGE_TO_IMAGE_GRAPH) { + graph.edges.push({ + source: { + node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER, + field: 'vae', + }, + destination: { + node_id: IMAGE_TO_LATENTS, + field: 'vae', + }, + }); + } + + if (graph.id === INPAINT_GRAPH) { + graph.edges.push({ + source: { + node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER, + field: 'vae', + }, + destination: { + node_id: INPAINT, + field: 'vae', + }, + }); + } +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts index 49bab291f7..1843efef84 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts @@ -1,31 +1,27 @@ +import { log } from 'app/logging/useLogger'; import { RootState } from 'app/store/store'; +import { NonNullableGraph } from 'features/nodes/types/types'; import { ImageDTO, ImageResizeInvocation, ImageToLatentsInvocation, - RandomIntInvocation, - RangeOfSizeInvocation, } from 'services/api/types'; -import { NonNullableGraph } from 'features/nodes/types/types'; -import { log } from 'app/logging/useLogger'; +import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; +import { modelIdToMainModelField } from '../modelIdToMainModelField'; +import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addLoRAsToGraph } from './addLoRAsToGraph'; +import { addVAEToGraph } from './addVAEToGraph'; import { - ITERATE, + IMAGE_TO_IMAGE_GRAPH, + IMAGE_TO_LATENTS, LATENTS_TO_IMAGE, - PIPELINE_MODEL_LOADER, + LATENTS_TO_LATENTS, + MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, NOISE, POSITIVE_CONDITIONING, - RANDOM_INT, - RANGE_OF_SIZE, - IMAGE_TO_IMAGE_GRAPH, - IMAGE_TO_LATENTS, - LATENTS_TO_LATENTS, RESIZE, } from './constants'; -import { set } from 'lodash-es'; -import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; -import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; -import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; const moduleLog = log.child({ namespace: 'nodes' }); @@ -52,7 +48,7 @@ export const buildCanvasImageToImageGraph = ( // The bounding box determines width and height, not the width and height params const { width, height } = state.canvas.boundingBoxDimensions; - const model = modelIdToPipelineModelField(modelId); + const model = modelIdToMainModelField(modelId); /** * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the @@ -81,9 +77,9 @@ export const buildCanvasImageToImageGraph = ( type: 'noise', id: NOISE, }, - [PIPELINE_MODEL_LOADER]: { - type: 'pipeline_model_loader', - id: PIPELINE_MODEL_LOADER, + [MAIN_MODEL_LOADER]: { + type: 'main_model_loader', + id: MAIN_MODEL_LOADER, model, }, [LATENTS_TO_IMAGE]: { @@ -110,7 +106,7 @@ export const buildCanvasImageToImageGraph = ( edges: [ { source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'clip', }, destination: { @@ -120,7 +116,7 @@ export const buildCanvasImageToImageGraph = ( }, { source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'clip', }, destination: { @@ -128,16 +124,6 @@ export const buildCanvasImageToImageGraph = ( field: 'clip', }, }, - { - source: { - node_id: PIPELINE_MODEL_LOADER, - field: 'vae', - }, - destination: { - node_id: LATENTS_TO_IMAGE, - field: 'vae', - }, - }, { source: { node_id: LATENTS_TO_LATENTS, @@ -170,17 +156,7 @@ export const buildCanvasImageToImageGraph = ( }, { source: { - node_id: PIPELINE_MODEL_LOADER, - field: 'vae', - }, - destination: { - node_id: IMAGE_TO_LATENTS, - field: 'vae', - }, - }, - { - source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'unet', }, destination: { @@ -277,6 +253,11 @@ export const buildCanvasImageToImageGraph = ( }); } + addLoRAsToGraph(graph, state, LATENTS_TO_LATENTS); + + // Add VAE + addVAEToGraph(graph, state); + // add dynamic prompts, mutating `graph` addDynamicPromptsToGraph(graph, state); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts index 74bd12a742..c4f9415067 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts @@ -1,23 +1,25 @@ +import { log } from 'app/logging/useLogger'; import { RootState } from 'app/store/store'; +import { NonNullableGraph } from 'features/nodes/types/types'; import { ImageDTO, InpaintInvocation, RandomIntInvocation, RangeOfSizeInvocation, } from 'services/api/types'; -import { NonNullableGraph } from 'features/nodes/types/types'; -import { log } from 'app/logging/useLogger'; +import { modelIdToMainModelField } from '../modelIdToMainModelField'; +import { addLoRAsToGraph } from './addLoRAsToGraph'; +import { addVAEToGraph } from './addVAEToGraph'; import { + INPAINT, + INPAINT_GRAPH, ITERATE, - PIPELINE_MODEL_LOADER, + MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, POSITIVE_CONDITIONING, RANDOM_INT, RANGE_OF_SIZE, - INPAINT_GRAPH, - INPAINT, } from './constants'; -import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; const moduleLog = log.child({ namespace: 'nodes' }); @@ -55,7 +57,7 @@ export const buildCanvasInpaintGraph = ( // We may need to set the inpaint width and height to scale the image const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas; - const model = modelIdToPipelineModelField(modelId); + const model = modelIdToMainModelField(modelId); const graph: NonNullableGraph = { id: INPAINT_GRAPH, @@ -101,9 +103,9 @@ export const buildCanvasInpaintGraph = ( id: NEGATIVE_CONDITIONING, prompt: negativePrompt, }, - [PIPELINE_MODEL_LOADER]: { - type: 'pipeline_model_loader', - id: PIPELINE_MODEL_LOADER, + [MAIN_MODEL_LOADER]: { + type: 'main_model_loader', + id: MAIN_MODEL_LOADER, model, }, [RANGE_OF_SIZE]: { @@ -142,7 +144,7 @@ export const buildCanvasInpaintGraph = ( }, { source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'clip', }, destination: { @@ -152,7 +154,7 @@ export const buildCanvasInpaintGraph = ( }, { source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'clip', }, destination: { @@ -162,7 +164,7 @@ export const buildCanvasInpaintGraph = ( }, { source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'unet', }, destination: { @@ -170,16 +172,6 @@ export const buildCanvasInpaintGraph = ( field: 'unet', }, }, - { - source: { - node_id: PIPELINE_MODEL_LOADER, - field: 'vae', - }, - destination: { - node_id: INPAINT, - field: 'vae', - }, - }, { source: { node_id: RANGE_OF_SIZE, @@ -203,6 +195,11 @@ export const buildCanvasInpaintGraph = ( ], }; + addLoRAsToGraph(graph, state, INPAINT); + + // Add VAE + addVAEToGraph(graph, state); + // handle seed if (shouldRandomizeSeed) { // Random int node to generate the starting seed diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts index b15b2cd192..976ea4fd01 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts @@ -1,21 +1,19 @@ import { RootState } from 'app/store/store'; import { NonNullableGraph } from 'features/nodes/types/types'; -import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api/types'; +import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; +import { modelIdToMainModelField } from '../modelIdToMainModelField'; +import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addLoRAsToGraph } from './addLoRAsToGraph'; +import { addVAEToGraph } from './addVAEToGraph'; import { - ITERATE, LATENTS_TO_IMAGE, - PIPELINE_MODEL_LOADER, + MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, NOISE, POSITIVE_CONDITIONING, - RANDOM_INT, - RANGE_OF_SIZE, TEXT_TO_IMAGE_GRAPH, TEXT_TO_LATENTS, } from './constants'; -import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; -import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; -import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; /** * Builds the Canvas tab's Text to Image graph. @@ -38,7 +36,7 @@ export const buildCanvasTextToImageGraph = ( // The bounding box determines width and height, not the width and height params const { width, height } = state.canvas.boundingBoxDimensions; - const model = modelIdToPipelineModelField(modelId); + const model = modelIdToMainModelField(modelId); /** * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the @@ -76,9 +74,9 @@ export const buildCanvasTextToImageGraph = ( scheduler, steps, }, - [PIPELINE_MODEL_LOADER]: { - type: 'pipeline_model_loader', - id: PIPELINE_MODEL_LOADER, + [MAIN_MODEL_LOADER]: { + type: 'main_model_loader', + id: MAIN_MODEL_LOADER, model, }, [LATENTS_TO_IMAGE]: { @@ -109,7 +107,7 @@ export const buildCanvasTextToImageGraph = ( }, { source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'clip', }, destination: { @@ -119,7 +117,7 @@ export const buildCanvasTextToImageGraph = ( }, { source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'clip', }, destination: { @@ -129,7 +127,7 @@ export const buildCanvasTextToImageGraph = ( }, { source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'unet', }, destination: { @@ -147,16 +145,6 @@ export const buildCanvasTextToImageGraph = ( field: 'latents', }, }, - { - source: { - node_id: PIPELINE_MODEL_LOADER, - field: 'vae', - }, - destination: { - node_id: LATENTS_TO_IMAGE, - field: 'vae', - }, - }, { source: { node_id: NOISE, @@ -170,6 +158,11 @@ export const buildCanvasTextToImageGraph = ( ], }; + addLoRAsToGraph(graph, state, TEXT_TO_LATENTS); + + // Add VAE + addVAEToGraph(graph, state); + // add dynamic prompts, mutating `graph` addDynamicPromptsToGraph(graph, state); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts index ca0a2e4dd9..fe6d1292e4 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts @@ -1,28 +1,30 @@ +import { log } from 'app/logging/useLogger'; import { RootState } from 'app/store/store'; +import { NonNullableGraph } from 'features/nodes/types/types'; import { ImageCollectionInvocation, ImageResizeInvocation, ImageToLatentsInvocation, IterateInvocation, } from 'services/api/types'; -import { NonNullableGraph } from 'features/nodes/types/types'; -import { log } from 'app/logging/useLogger'; +import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; +import { modelIdToMainModelField } from '../modelIdToMainModelField'; +import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addLoRAsToGraph } from './addLoRAsToGraph'; +import { addVAEToGraph } from './addVAEToGraph'; import { + IMAGE_COLLECTION, + IMAGE_COLLECTION_ITERATE, + IMAGE_TO_IMAGE_GRAPH, + IMAGE_TO_LATENTS, LATENTS_TO_IMAGE, - PIPELINE_MODEL_LOADER, + LATENTS_TO_LATENTS, + MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, NOISE, POSITIVE_CONDITIONING, - IMAGE_TO_IMAGE_GRAPH, - IMAGE_TO_LATENTS, - LATENTS_TO_LATENTS, RESIZE, - IMAGE_COLLECTION, - IMAGE_COLLECTION_ITERATE, } from './constants'; -import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; -import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; -import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; const moduleLog = log.child({ namespace: 'nodes' }); @@ -69,7 +71,7 @@ export const buildLinearImageToImageGraph = ( throw new Error('No initial image found in state'); } - const model = modelIdToPipelineModelField(modelId); + const model = modelIdToMainModelField(modelId); // copy-pasted graph from node editor, filled in with state values & friendly node ids const graph: NonNullableGraph = { @@ -89,9 +91,9 @@ export const buildLinearImageToImageGraph = ( type: 'noise', id: NOISE, }, - [PIPELINE_MODEL_LOADER]: { - type: 'pipeline_model_loader', - id: PIPELINE_MODEL_LOADER, + [MAIN_MODEL_LOADER]: { + type: 'main_model_loader', + id: MAIN_MODEL_LOADER, model, }, [LATENTS_TO_IMAGE]: { @@ -118,7 +120,7 @@ export const buildLinearImageToImageGraph = ( edges: [ { source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'clip', }, destination: { @@ -128,7 +130,7 @@ export const buildLinearImageToImageGraph = ( }, { source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'clip', }, destination: { @@ -136,16 +138,6 @@ export const buildLinearImageToImageGraph = ( field: 'clip', }, }, - { - source: { - node_id: PIPELINE_MODEL_LOADER, - field: 'vae', - }, - destination: { - node_id: LATENTS_TO_IMAGE, - field: 'vae', - }, - }, { source: { node_id: LATENTS_TO_LATENTS, @@ -176,19 +168,10 @@ export const buildLinearImageToImageGraph = ( field: 'noise', }, }, + { source: { - node_id: PIPELINE_MODEL_LOADER, - field: 'vae', - }, - destination: { - node_id: IMAGE_TO_LATENTS, - field: 'vae', - }, - }, - { - source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'unet', }, destination: { @@ -323,6 +306,11 @@ export const buildLinearImageToImageGraph = ( }); } + addLoRAsToGraph(graph, state, LATENTS_TO_LATENTS); + + // Add VAE + addVAEToGraph(graph, state); + // add dynamic prompts, mutating `graph` addDynamicPromptsToGraph(graph, state); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts index 216c5c8c67..04dccf4983 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts @@ -1,17 +1,19 @@ import { RootState } from 'app/store/store'; import { NonNullableGraph } from 'features/nodes/types/types'; +import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; +import { modelIdToMainModelField } from '../modelIdToMainModelField'; +import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addLoRAsToGraph } from './addLoRAsToGraph'; +import { addVAEToGraph } from './addVAEToGraph'; import { LATENTS_TO_IMAGE, - PIPELINE_MODEL_LOADER, + MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, NOISE, POSITIVE_CONDITIONING, TEXT_TO_IMAGE_GRAPH, TEXT_TO_LATENTS, } from './constants'; -import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; -import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; -import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; export const buildLinearTextToImageGraph = ( state: RootState @@ -27,7 +29,7 @@ export const buildLinearTextToImageGraph = ( height, } = state.generation; - const model = modelIdToPipelineModelField(modelId); + const model = modelIdToMainModelField(modelId); /** * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the @@ -65,9 +67,9 @@ export const buildLinearTextToImageGraph = ( scheduler, steps, }, - [PIPELINE_MODEL_LOADER]: { - type: 'pipeline_model_loader', - id: PIPELINE_MODEL_LOADER, + [MAIN_MODEL_LOADER]: { + type: 'main_model_loader', + id: MAIN_MODEL_LOADER, model, }, [LATENTS_TO_IMAGE]: { @@ -98,7 +100,7 @@ export const buildLinearTextToImageGraph = ( }, { source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'clip', }, destination: { @@ -108,7 +110,7 @@ export const buildLinearTextToImageGraph = ( }, { source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'clip', }, destination: { @@ -118,7 +120,7 @@ export const buildLinearTextToImageGraph = ( }, { source: { - node_id: PIPELINE_MODEL_LOADER, + node_id: MAIN_MODEL_LOADER, field: 'unet', }, destination: { @@ -136,16 +138,6 @@ export const buildLinearTextToImageGraph = ( field: 'latents', }, }, - { - source: { - node_id: PIPELINE_MODEL_LOADER, - field: 'vae', - }, - destination: { - node_id: LATENTS_TO_IMAGE, - field: 'vae', - }, - }, { source: { node_id: NOISE, @@ -159,6 +151,11 @@ export const buildLinearTextToImageGraph = ( ], }; + addLoRAsToGraph(graph, state, TEXT_TO_LATENTS); + + // Add Custom VAE Support + addVAEToGraph(graph, state); + // add dynamic prompts, mutating `graph` addDynamicPromptsToGraph(graph, state); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts index 091899a21a..12a567b009 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts @@ -1,10 +1,12 @@ -import { Graph } from 'services/api/types'; -import { v4 as uuidv4 } from 'uuid'; -import { cloneDeep, omit, reduce } from 'lodash-es'; import { RootState } from 'app/store/store'; import { InputFieldValue } from 'features/nodes/types/types'; +import { cloneDeep, omit, reduce } from 'lodash-es'; +import { Graph } from 'services/api/types'; import { AnyInvocation } from 'services/events/types'; -import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; +import { v4 as uuidv4 } from 'uuid'; +import { modelIdToLoRAModelField } from '../modelIdToLoRAName'; +import { modelIdToMainModelField } from '../modelIdToMainModelField'; +import { modelIdToVAEModelField } from '../modelIdToVAEModelField'; /** * We need to do special handling for some fields @@ -27,7 +29,19 @@ export const parseFieldValue = (field: InputFieldValue) => { if (field.type === 'model') { if (field.value) { - return modelIdToPipelineModelField(field.value); + return modelIdToMainModelField(field.value); + } + } + + if (field.type === 'vae_model') { + if (field.value) { + return modelIdToVAEModelField(field.value); + } + } + + if (field.type === 'lora_model') { + if (field.value) { + return modelIdToLoRAModelField(field.value); } } diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts index b0b1edde30..7aace48def 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts @@ -7,7 +7,9 @@ export const NOISE = 'noise'; export const RANDOM_INT = 'rand_int'; export const RANGE_OF_SIZE = 'range_of_size'; export const ITERATE = 'iterate'; -export const PIPELINE_MODEL_LOADER = 'pipeline_model_loader'; +export const MAIN_MODEL_LOADER = 'main_model_loader'; +export const VAE_LOADER = 'vae_loader'; +export const LORA_LOADER = 'lora_loader'; export const IMAGE_TO_LATENTS = 'image_to_latents'; export const LATENTS_TO_LATENTS = 'latents_to_latents'; export const RESIZE = 'resize_image'; diff --git a/invokeai/frontend/web/src/features/nodes/util/modelIdToLoRAName.ts b/invokeai/frontend/web/src/features/nodes/util/modelIdToLoRAName.ts new file mode 100644 index 0000000000..052b58484b --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/modelIdToLoRAName.ts @@ -0,0 +1,12 @@ +import { BaseModelType, LoRAModelField } from 'services/api/types'; + +export const modelIdToLoRAModelField = (loraId: string): LoRAModelField => { + const [base_model, model_type, model_name] = loraId.split('/'); + + const field: LoRAModelField = { + base_model: base_model as BaseModelType, + model_name, + }; + + return field; +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/modelIdToMainModelField.ts b/invokeai/frontend/web/src/features/nodes/util/modelIdToMainModelField.ts new file mode 100644 index 0000000000..6bb0f776b2 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/modelIdToMainModelField.ts @@ -0,0 +1,16 @@ +import { BaseModelType, MainModelField } from 'services/api/types'; + +/** + * Crudely converts a model id to a main model field + * TODO: Make better + */ +export const modelIdToMainModelField = (modelId: string): MainModelField => { + const [base_model, model_type, model_name] = modelId.split('/'); + + const field: MainModelField = { + base_model: base_model as BaseModelType, + model_name, + }; + + return field; +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/modelIdToPipelineModelField.ts b/invokeai/frontend/web/src/features/nodes/util/modelIdToPipelineModelField.ts deleted file mode 100644 index 0941255181..0000000000 --- a/invokeai/frontend/web/src/features/nodes/util/modelIdToPipelineModelField.ts +++ /dev/null @@ -1,18 +0,0 @@ -import { BaseModelType, PipelineModelField } from 'services/api/types'; - -/** - * Crudely converts a model id to a pipeline model field - * TODO: Make better - */ -export const modelIdToPipelineModelField = ( - modelId: string -): PipelineModelField => { - const [base_model, model_type, model_name] = modelId.split('/'); - - const field: PipelineModelField = { - base_model: base_model as BaseModelType, - model_name, - }; - - return field; -}; diff --git a/invokeai/frontend/web/src/features/nodes/util/modelIdToVAEModelField.ts b/invokeai/frontend/web/src/features/nodes/util/modelIdToVAEModelField.ts new file mode 100644 index 0000000000..0cb608a936 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/modelIdToVAEModelField.ts @@ -0,0 +1,16 @@ +import { BaseModelType, VAEModelField } from 'services/api/types'; + +/** + * Crudely converts a model id to a main model field + * TODO: Make better + */ +export const modelIdToVAEModelField = (modelId: string): VAEModelField => { + const [base_model, model_type, model_name] = modelId.split('/'); + + const field: VAEModelField = { + base_model: base_model as BaseModelType, + model_name, + }; + + return field; +}; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxCollapse.tsx index fea0d8330a..b9cc8511aa 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxCollapse.tsx @@ -1,20 +1,15 @@ -import { Flex, useDisclosure } from '@chakra-ui/react'; -import { useTranslation } from 'react-i18next'; +import { Flex } from '@chakra-ui/react'; import IAICollapse from 'common/components/IAICollapse'; import { memo } from 'react'; -import ParamBoundingBoxWidth from './ParamBoundingBoxWidth'; +import { useTranslation } from 'react-i18next'; import ParamBoundingBoxHeight from './ParamBoundingBoxHeight'; +import ParamBoundingBoxWidth from './ParamBoundingBoxWidth'; const ParamBoundingBoxCollapse = () => { const { t } = useTranslation(); - const { isOpen, onToggle } = useDisclosure(); return ( - + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse.tsx index ed01da9876..a531eba57f 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse.tsx @@ -1,4 +1,4 @@ -import { Flex, useDisclosure } from '@chakra-ui/react'; +import { Flex } from '@chakra-ui/react'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -6,19 +6,14 @@ import IAICollapse from 'common/components/IAICollapse'; import ParamInfillMethod from './ParamInfillMethod'; import ParamInfillTilesize from './ParamInfillTilesize'; import ParamScaleBeforeProcessing from './ParamScaleBeforeProcessing'; -import ParamScaledWidth from './ParamScaledWidth'; import ParamScaledHeight from './ParamScaledHeight'; +import ParamScaledWidth from './ParamScaledWidth'; const ParamInfillCollapse = () => { const { t } = useTranslation(); - const { isOpen, onToggle } = useDisclosure(); return ( - + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse.tsx index 992e8b6d02..88d839fa15 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse.tsx @@ -1,22 +1,16 @@ +import IAICollapse from 'common/components/IAICollapse'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; import ParamSeamBlur from './ParamSeamBlur'; import ParamSeamSize from './ParamSeamSize'; import ParamSeamSteps from './ParamSeamSteps'; import ParamSeamStrength from './ParamSeamStrength'; -import { useDisclosure } from '@chakra-ui/react'; -import { useTranslation } from 'react-i18next'; -import IAICollapse from 'common/components/IAICollapse'; -import { memo } from 'react'; const ParamSeamCorrectionCollapse = () => { const { t } = useTranslation(); - const { isOpen, onToggle } = useDisclosure(); return ( - + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx index 06c6108dcb..59bf7542eb 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx @@ -1,41 +1,45 @@ import { Divider, Flex } from '@chakra-ui/react'; -import { useTranslation } from 'react-i18next'; -import IAICollapse from 'common/components/IAICollapse'; -import { Fragment, memo, useCallback } from 'react'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { createSelector } from '@reduxjs/toolkit'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIButton from 'common/components/IAIButton'; +import IAICollapse from 'common/components/IAICollapse'; +import ControlNet from 'features/controlNet/components/ControlNet'; +import ParamControlNetFeatureToggle from 'features/controlNet/components/parameters/ParamControlNetFeatureToggle'; import { controlNetAdded, controlNetSelector, - isControlNetEnabledToggled, } from 'features/controlNet/store/controlNetSlice'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { map } from 'lodash-es'; -import { v4 as uuidv4 } from 'uuid'; +import { getValidControlNets } from 'features/controlNet/util/getValidControlNets'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; -import IAIButton from 'common/components/IAIButton'; -import ControlNet from 'features/controlNet/components/ControlNet'; +import { map } from 'lodash-es'; +import { Fragment, memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { v4 as uuidv4 } from 'uuid'; const selector = createSelector( controlNetSelector, (controlNet) => { const { controlNets, isEnabled } = controlNet; - return { controlNetsArray: map(controlNets), isEnabled }; + const validControlNets = getValidControlNets(controlNets); + + const activeLabel = + isEnabled && validControlNets.length > 0 + ? `${validControlNets.length} Active` + : undefined; + + return { controlNetsArray: map(controlNets), activeLabel }; }, defaultSelectorOptions ); const ParamControlNetCollapse = () => { const { t } = useTranslation(); - const { controlNetsArray, isEnabled } = useAppSelector(selector); + const { controlNetsArray, activeLabel } = useAppSelector(selector); const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled; const dispatch = useAppDispatch(); - const handleClickControlNetToggle = useCallback(() => { - dispatch(isControlNetEnabledToggled()); - }, [dispatch]); - const handleClickedAddControlNet = useCallback(() => { dispatch(controlNetAdded({ controlNetId: uuidv4() })); }, [dispatch]); @@ -45,13 +49,9 @@ const ParamControlNetCollapse = () => { } return ( - + + {controlNetsArray.map((c, i) => ( {i > 0 && } diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamCFGScale.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamCFGScale.tsx index 111e3d3ae8..d32ff960d5 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamCFGScale.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamCFGScale.tsx @@ -1,5 +1,6 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAINumberInput from 'common/components/IAINumberInput'; import IAISlider from 'common/components/IAISlider'; import { generationSelector } from 'features/parameters/store/generationSelectors'; @@ -27,7 +28,8 @@ const selector = createSelector( shouldUseSliders, shift, }; - } + }, + defaultSelectorOptions ); const ParamCFGScale = () => { diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamHeight.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamHeight.tsx index 9501c8b475..6939ede424 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamHeight.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamHeight.tsx @@ -1,5 +1,6 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAISlider, { IAIFullSliderProps } from 'common/components/IAISlider'; import { generationSelector } from 'features/parameters/store/generationSelectors'; import { setHeight } from 'features/parameters/store/generationSlice'; @@ -25,7 +26,8 @@ const selector = createSelector( inputMax, step, }; - } + }, + defaultSelectorOptions ); type ParamHeightProps = Omit< diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamIterations.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamIterations.tsx index a8cdabc8c9..1e203a1e45 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamIterations.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamIterations.tsx @@ -1,37 +1,38 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAINumberInput from 'common/components/IAINumberInput'; import IAISlider from 'common/components/IAISlider'; -import { generationSelector } from 'features/parameters/store/generationSelectors'; import { setIterations } from 'features/parameters/store/generationSlice'; -import { configSelector } from 'features/system/store/configSelectors'; -import { hotkeysSelector } from 'features/ui/store/hotkeysSlice'; -import { uiSelector } from 'features/ui/store/uiSelectors'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -const selector = createSelector([stateSelector], (state) => { - const { initial, min, sliderMax, inputMax, fineStep, coarseStep } = - state.config.sd.iterations; - const { iterations } = state.generation; - const { shouldUseSliders } = state.ui; - const isDisabled = - state.dynamicPrompts.isEnabled && state.dynamicPrompts.combinatorial; +const selector = createSelector( + [stateSelector], + (state) => { + const { initial, min, sliderMax, inputMax, fineStep, coarseStep } = + state.config.sd.iterations; + const { iterations } = state.generation; + const { shouldUseSliders } = state.ui; + const isDisabled = + state.dynamicPrompts.isEnabled && state.dynamicPrompts.combinatorial; - const step = state.hotkeys.shift ? fineStep : coarseStep; + const step = state.hotkeys.shift ? fineStep : coarseStep; - return { - iterations, - initial, - min, - sliderMax, - inputMax, - step, - shouldUseSliders, - isDisabled, - }; -}); + return { + iterations, + initial, + min, + sliderMax, + inputMax, + step, + shouldUseSliders, + isDisabled, + }; + }, + defaultSelectorOptions +); const ParamIterations = () => { const { diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSchedulerAndModel.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamModelandVAE.tsx similarity index 60% rename from invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSchedulerAndModel.tsx rename to invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamModelandVAE.tsx index 5092893eed..1c704a86ef 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSchedulerAndModel.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamModelandVAE.tsx @@ -1,19 +1,19 @@ import { Box, Flex } from '@chakra-ui/react'; import ModelSelect from 'features/system/components/ModelSelect'; +import VAESelect from 'features/system/components/VAESelect'; import { memo } from 'react'; -import ParamScheduler from './ParamScheduler'; -const ParamSchedulerAndModel = () => { +const ParamModelandVAE = () => { return ( - - - + + + ); }; -export default memo(ParamSchedulerAndModel); +export default memo(ParamModelandVAE); diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSteps.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSteps.tsx index f43cdd425b..d939113c7c 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSteps.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSteps.tsx @@ -1,5 +1,6 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAINumberInput from 'common/components/IAINumberInput'; import IAISlider from 'common/components/IAISlider'; @@ -33,7 +34,8 @@ const selector = createSelector( step, shouldUseSliders, }; - } + }, + defaultSelectorOptions ); const ParamSteps = () => { diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamWidth.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamWidth.tsx index b7d63038d1..b4121184b5 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamWidth.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamWidth.tsx @@ -1,7 +1,7 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import IAISlider from 'common/components/IAISlider'; -import { IAIFullSliderProps } from 'common/components/IAISlider'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAISlider, { IAIFullSliderProps } from 'common/components/IAISlider'; import { generationSelector } from 'features/parameters/store/generationSelectors'; import { setWidth } from 'features/parameters/store/generationSlice'; import { configSelector } from 'features/system/store/configSelectors'; @@ -26,7 +26,8 @@ const selector = createSelector( inputMax, step, }; - } + }, + defaultSelectorOptions ); type ParamWidthProps = Omit; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresCollapse.tsx index b4b077ad6c..fa8606d610 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresCollapse.tsx @@ -1,37 +1,39 @@ import { Flex } from '@chakra-ui/react'; -import { useTranslation } from 'react-i18next'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { RootState } from 'app/store/store'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAICollapse from 'common/components/IAICollapse'; -import { memo } from 'react'; -import { ParamHiresStrength } from './ParamHiresStrength'; -import { setHiresFix } from 'features/parameters/store/postprocessingSlice'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { ParamHiresStrength } from './ParamHiresStrength'; +import { ParamHiresToggle } from './ParamHiresToggle'; + +const selector = createSelector( + stateSelector, + (state) => { + const activeLabel = state.postprocessing.hiresFix ? 'Enabled' : undefined; + + return { activeLabel }; + }, + defaultSelectorOptions +); const ParamHiresCollapse = () => { const { t } = useTranslation(); - const hiresFix = useAppSelector( - (state: RootState) => state.postprocessing.hiresFix - ); + const { activeLabel } = useAppSelector(selector); const isHiresEnabled = useFeatureStatus('hires').isFeatureEnabled; - const dispatch = useAppDispatch(); - - const handleToggle = () => dispatch(setHiresFix(!hiresFix)); - if (!isHiresEnabled) { return null; } return ( - + + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresToggle.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresToggle.tsx index 0fc600e9e8..f8e6f22aa4 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresToggle.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresToggle.tsx @@ -23,7 +23,6 @@ export const ParamHiresToggle = () => { return ( diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseCollapse.tsx index adb76d8da0..4dea1dad4f 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseCollapse.tsx @@ -1,27 +1,33 @@ -import { useTranslation } from 'react-i18next'; import { Flex } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAICollapse from 'common/components/IAICollapse'; -import ParamPerlinNoise from './ParamPerlinNoise'; -import ParamNoiseThreshold from './ParamNoiseThreshold'; -import { RootState } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { setShouldUseNoiseSettings } from 'features/parameters/store/generationSlice'; -import { memo } from 'react'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; +import ParamNoiseThreshold from './ParamNoiseThreshold'; +import { ParamNoiseToggle } from './ParamNoiseToggle'; +import ParamPerlinNoise from './ParamPerlinNoise'; + +const selector = createSelector( + stateSelector, + (state) => { + const { shouldUseNoiseSettings } = state.generation; + return { + activeLabel: shouldUseNoiseSettings ? 'Enabled' : undefined, + }; + }, + defaultSelectorOptions +); const ParamNoiseCollapse = () => { const { t } = useTranslation(); const isNoiseEnabled = useFeatureStatus('noise').isFeatureEnabled; - const shouldUseNoiseSettings = useAppSelector( - (state: RootState) => state.generation.shouldUseNoiseSettings - ); - - const dispatch = useAppDispatch(); - - const handleToggle = () => - dispatch(setShouldUseNoiseSettings(!shouldUseNoiseSettings)); + const { activeLabel } = useAppSelector(selector); if (!isNoiseEnabled) { return null; @@ -30,11 +36,10 @@ const ParamNoiseCollapse = () => { return ( + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseThreshold.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseThreshold.tsx index e339734992..3abb7532b4 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseThreshold.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseThreshold.tsx @@ -1,18 +1,31 @@ -import { RootState } from 'app/store/store'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAISlider from 'common/components/IAISlider'; import { setThreshold } from 'features/parameters/store/generationSlice'; import { useTranslation } from 'react-i18next'; +const selector = createSelector( + stateSelector, + (state) => { + const { shouldUseNoiseSettings, threshold } = state.generation; + return { + isDisabled: !shouldUseNoiseSettings, + threshold, + }; + }, + defaultSelectorOptions +); + export default function ParamNoiseThreshold() { const dispatch = useAppDispatch(); - const threshold = useAppSelector( - (state: RootState) => state.generation.threshold - ); + const { threshold, isDisabled } = useAppSelector(selector); const { t } = useTranslation(); return ( { + const dispatch = useAppDispatch(); + + const shouldUseNoiseSettings = useAppSelector( + (state: RootState) => state.generation.shouldUseNoiseSettings + ); + + const { t } = useTranslation(); + + const handleChange = (e: ChangeEvent) => + dispatch(setShouldUseNoiseSettings(e.target.checked)); + + return ( + + ); +}; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamPerlinNoise.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamPerlinNoise.tsx index ad710eae54..afd676223c 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamPerlinNoise.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamPerlinNoise.tsx @@ -1,16 +1,31 @@ -import { RootState } from 'app/store/store'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAISlider from 'common/components/IAISlider'; import { setPerlin } from 'features/parameters/store/generationSlice'; import { useTranslation } from 'react-i18next'; +const selector = createSelector( + stateSelector, + (state) => { + const { shouldUseNoiseSettings, perlin } = state.generation; + return { + isDisabled: !shouldUseNoiseSettings, + perlin, + }; + }, + defaultSelectorOptions +); + export default function ParamPerlinNoise() { const dispatch = useAppDispatch(); - const perlin = useAppSelector((state: RootState) => state.generation.perlin); + const { perlin, isDisabled } = useAppSelector(selector); const { t } = useTranslation(); return ( { + if (seamlessXAxis && seamlessYAxis) { + return 'X & Y'; + } + + if (seamlessXAxis) { + return 'X'; + } + + if (seamlessYAxis) { + return 'Y'; + } +}; const selector = createSelector( generationSelector, (generation) => { - const { shouldUseSeamless, seamlessXAxis, seamlessYAxis } = generation; + const { seamlessXAxis, seamlessYAxis } = generation; - return { shouldUseSeamless, seamlessXAxis, seamlessYAxis }; + const activeLabel = getActiveLabel(seamlessXAxis, seamlessYAxis); + return { activeLabel }; }, defaultSelectorOptions ); const ParamSeamlessCollapse = () => { const { t } = useTranslation(); - const { shouldUseSeamless } = useAppSelector(selector); + const { activeLabel } = useAppSelector(selector); const isSeamlessEnabled = useFeatureStatus('seamless').isFeatureEnabled; - const dispatch = useAppDispatch(); - - const handleToggle = () => dispatch(setSeamless(!shouldUseSeamless)); - if (!isSeamlessEnabled) { return null; } @@ -38,9 +48,7 @@ const ParamSeamlessCollapse = () => { return ( diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse.tsx index 59bdb39be1..f2ddd19768 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse.tsx @@ -1,39 +1,39 @@ -import { memo } from 'react'; import { Flex } from '@chakra-ui/react'; +import { memo } from 'react'; import ParamSymmetryHorizontal from './ParamSymmetryHorizontal'; import ParamSymmetryVertical from './ParamSymmetryVertical'; -import { useTranslation } from 'react-i18next'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAICollapse from 'common/components/IAICollapse'; -import { RootState } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { setShouldUseSymmetry } from 'features/parameters/store/generationSlice'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; +import { useTranslation } from 'react-i18next'; +import ParamSymmetryToggle from './ParamSymmetryToggle'; + +const selector = createSelector( + stateSelector, + (state) => ({ + activeLabel: state.generation.shouldUseSymmetry ? 'Enabled' : undefined, + }), + defaultSelectorOptions +); const ParamSymmetryCollapse = () => { const { t } = useTranslation(); - const shouldUseSymmetry = useAppSelector( - (state: RootState) => state.generation.shouldUseSymmetry - ); + const { activeLabel } = useAppSelector(selector); const isSymmetryEnabled = useFeatureStatus('symmetry').isFeatureEnabled; - const dispatch = useAppDispatch(); - - const handleToggle = () => dispatch(setShouldUseSymmetry(!shouldUseSymmetry)); - if (!isSymmetryEnabled) { return null; } return ( - + + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryToggle.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryToggle.tsx index 7cc17c045e..59386ff526 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryToggle.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryToggle.tsx @@ -12,6 +12,7 @@ export default function ParamSymmetryToggle() { return ( dispatch(setShouldUseSymmetry(e.target.checked))} /> diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationCollapse.tsx index 1564bd64e5..3cdfc3a06b 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationCollapse.tsx @@ -1,39 +1,42 @@ -import ParamVariationWeights from './ParamVariationWeights'; -import ParamVariationAmount from './ParamVariationAmount'; -import { useTranslation } from 'react-i18next'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { RootState } from 'app/store/store'; -import { setShouldGenerateVariations } from 'features/parameters/store/generationSlice'; import { Flex } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAICollapse from 'common/components/IAICollapse'; -import { memo } from 'react'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; +import ParamVariationAmount from './ParamVariationAmount'; +import { ParamVariationToggle } from './ParamVariationToggle'; +import ParamVariationWeights from './ParamVariationWeights'; + +const selector = createSelector( + stateSelector, + (state) => { + const activeLabel = state.generation.shouldGenerateVariations + ? 'Enabled' + : undefined; + + return { activeLabel }; + }, + defaultSelectorOptions +); const ParamVariationCollapse = () => { const { t } = useTranslation(); - const shouldGenerateVariations = useAppSelector( - (state: RootState) => state.generation.shouldGenerateVariations - ); + const { activeLabel } = useAppSelector(selector); const isVariationEnabled = useFeatureStatus('variation').isFeatureEnabled; - const dispatch = useAppDispatch(); - - const handleToggle = () => - dispatch(setShouldGenerateVariations(!shouldGenerateVariations)); - if (!isVariationEnabled) { return null; } return ( - + + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationToggle.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationToggle.tsx new file mode 100644 index 0000000000..1c05468de0 --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationToggle.tsx @@ -0,0 +1,27 @@ +import type { RootState } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import IAISwitch from 'common/components/IAISwitch'; +import { setShouldGenerateVariations } from 'features/parameters/store/generationSlice'; +import { ChangeEvent } from 'react'; +import { useTranslation } from 'react-i18next'; + +export const ParamVariationToggle = () => { + const dispatch = useAppDispatch(); + + const shouldGenerateVariations = useAppSelector( + (state: RootState) => state.generation.shouldGenerateVariations + ); + + const { t } = useTranslation(); + + const handleChange = (e: ChangeEvent) => + dispatch(setShouldGenerateVariations(e.target.checked)); + + return ( + + ); +}; diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index c8e65314da..960a41bb45 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -14,6 +14,7 @@ import { SeedParam, StepsParam, StrengthParam, + VAEParam, WidthParam, } from './parameterZodSchemas'; @@ -47,7 +48,7 @@ export interface GenerationState { horizontalSymmetrySteps: number; verticalSymmetrySteps: number; model: ModelParam; - shouldUseSeamless: boolean; + vae: VAEParam; seamlessXAxis: boolean; seamlessYAxis: boolean; } @@ -81,9 +82,9 @@ export const initialGenerationState: GenerationState = { horizontalSymmetrySteps: 0, verticalSymmetrySteps: 0, model: '', - shouldUseSeamless: false, - seamlessXAxis: true, - seamlessYAxis: true, + vae: '', + seamlessXAxis: false, + seamlessYAxis: false, }; const initialState: GenerationState = initialGenerationState; @@ -141,9 +142,6 @@ export const generationSlice = createSlice({ setImg2imgStrength: (state, action: PayloadAction) => { state.img2imgStrength = action.payload; }, - setSeamless: (state, action: PayloadAction) => { - state.shouldUseSeamless = action.payload; - }, setSeamlessXAxis: (state, action: PayloadAction) => { state.seamlessXAxis = action.payload; }, @@ -216,6 +214,9 @@ export const generationSlice = createSlice({ modelSelected: (state, action: PayloadAction) => { state.model = action.payload; }, + vaeSelected: (state, action: PayloadAction) => { + state.vae = action.payload; + }, }, extraReducers: (builder) => { builder.addCase(configChanged, (state, action) => { @@ -260,8 +261,8 @@ export const { setVerticalSymmetrySteps, initialImageChanged, modelSelected, + vaeSelected, setShouldUseNoiseSettings, - setSeamless, setSeamlessXAxis, setSeamlessYAxis, } = generationSlice.actions; diff --git a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts index 48eb309e7d..12d77beeb9 100644 --- a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts @@ -135,6 +135,15 @@ export const zModel = z.string(); * Type alias for model parameter, inferred from its zod schema */ export type ModelParam = z.infer; +/** + * Zod schema for VAE parameter + * TODO: Make this a dynamically generated enum? + */ +export const zVAE = z.string(); +/** + * Type alias for model parameter, inferred from its zod schema + */ +export type VAEParam = z.infer; /** * Validates/type-guards a value as a model parameter */ diff --git a/invokeai/frontend/web/src/features/system/components/ModelManager/AddModel.tsx b/invokeai/frontend/web/src/features/system/components/ModelManager/AddModel.tsx deleted file mode 100644 index bd0d0e5d3a..0000000000 --- a/invokeai/frontend/web/src/features/system/components/ModelManager/AddModel.tsx +++ /dev/null @@ -1,125 +0,0 @@ -import { - Button, - Flex, - Modal, - ModalBody, - ModalCloseButton, - ModalContent, - ModalFooter, - ModalHeader, - ModalOverlay, - Text, - useDisclosure, -} from '@chakra-ui/react'; - -import IAIButton from 'common/components/IAIButton'; - -import { FaArrowLeft, FaPlus } from 'react-icons/fa'; - -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { useTranslation } from 'react-i18next'; - -import type { RootState } from 'app/store/store'; -import { setAddNewModelUIOption } from 'features/ui/store/uiSlice'; -import AddCheckpointModel from './AddCheckpointModel'; -import AddDiffusersModel from './AddDiffusersModel'; -import IAIIconButton from 'common/components/IAIIconButton'; - -function AddModelBox({ - text, - onClick, -}: { - text: string; - onClick?: () => void; -}) { - return ( - - {text} - - ); -} - -export default function AddModel() { - const { isOpen, onOpen, onClose } = useDisclosure(); - - const addNewModelUIOption = useAppSelector( - (state: RootState) => state.ui.addNewModelUIOption - ); - - const dispatch = useAppDispatch(); - - const { t } = useTranslation(); - - const addModelModalClose = () => { - onClose(); - dispatch(setAddNewModelUIOption(null)); - }; - - return ( - <> - - - - {t('modelManager.addNew')} - - - - - - - {t('modelManager.addNewModel')} - {addNewModelUIOption !== null && ( - dispatch(setAddNewModelUIOption(null))} - position="absolute" - variant="ghost" - zIndex={1} - size="sm" - insetInlineEnd={12} - top={2} - icon={} - /> - )} - - - {addNewModelUIOption == null && ( - - dispatch(setAddNewModelUIOption('ckpt'))} - /> - dispatch(setAddNewModelUIOption('diffusers'))} - /> - - )} - {addNewModelUIOption == 'ckpt' && } - {addNewModelUIOption == 'diffusers' && } - - - - - - ); -} diff --git a/invokeai/frontend/web/src/features/system/components/ModelManager/CheckpointModelEdit.tsx b/invokeai/frontend/web/src/features/system/components/ModelManager/CheckpointModelEdit.tsx deleted file mode 100644 index b860a0848c..0000000000 --- a/invokeai/frontend/web/src/features/system/components/ModelManager/CheckpointModelEdit.tsx +++ /dev/null @@ -1,339 +0,0 @@ -import { createSelector } from '@reduxjs/toolkit'; - -import IAIButton from 'common/components/IAIButton'; -import IAIInput from 'common/components/IAIInput'; -import IAINumberInput from 'common/components/IAINumberInput'; -import { useEffect, useState } from 'react'; - -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { systemSelector } from 'features/system/store/systemSelectors'; - -import { - Flex, - FormControl, - FormLabel, - HStack, - Text, - VStack, -} from '@chakra-ui/react'; - -// import { addNewModel } from 'app/socketio/actions'; -import { Field, Formik } from 'formik'; -import { useTranslation } from 'react-i18next'; - -import type { InvokeModelConfigProps } from 'app/types/invokeai'; -import type { RootState } from 'app/store/store'; -import type { FieldInputProps, FormikProps } from 'formik'; -import { isEqual, pickBy } from 'lodash-es'; -import ModelConvert from './ModelConvert'; -import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText'; -import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage'; -import IAIForm from 'common/components/IAIForm'; - -const selector = createSelector( - [systemSelector], - (system) => { - const { openModel, model_list } = system; - return { - model_list, - openModel, - }; - }, - { - memoizeOptions: { - resultEqualityCheck: isEqual, - }, - } -); - -const MIN_MODEL_SIZE = 64; -const MAX_MODEL_SIZE = 2048; - -export default function CheckpointModelEdit() { - const { openModel, model_list } = useAppSelector(selector); - const isProcessing = useAppSelector( - (state: RootState) => state.system.isProcessing - ); - - const dispatch = useAppDispatch(); - - const { t } = useTranslation(); - - const [editModelFormValues, setEditModelFormValues] = - useState({ - name: '', - description: '', - config: 'configs/stable-diffusion/v1-inference.yaml', - weights: '', - vae: '', - width: 512, - height: 512, - default: false, - format: 'ckpt', - }); - - useEffect(() => { - if (openModel) { - const retrievedModel = pickBy(model_list, (_val, key) => { - return isEqual(key, openModel); - }); - setEditModelFormValues({ - name: openModel, - description: retrievedModel[openModel]?.description, - config: retrievedModel[openModel]?.config, - weights: retrievedModel[openModel]?.weights, - vae: retrievedModel[openModel]?.vae, - width: retrievedModel[openModel]?.width, - height: retrievedModel[openModel]?.height, - default: retrievedModel[openModel]?.default, - format: 'ckpt', - }); - } - }, [model_list, openModel]); - - const editModelFormSubmitHandler = (values: InvokeModelConfigProps) => { - dispatch( - addNewModel({ - ...values, - width: Number(values.width), - height: Number(values.height), - }) - ); - }; - - return openModel ? ( - - - - {openModel} - - - - - - {({ handleSubmit, errors, touched }) => ( - - - {/* Description */} - - - {t('modelManager.description')} - - - - {!!errors.description && touched.description ? ( - - {errors.description} - - ) : ( - - {t('modelManager.descriptionValidationMsg')} - - )} - - - - {/* Config */} - - - {t('modelManager.config')} - - - - {!!errors.config && touched.config ? ( - {errors.config} - ) : ( - - {t('modelManager.configValidationMsg')} - - )} - - - - {/* Weights */} - - - {t('modelManager.modelLocation')} - - - - {!!errors.weights && touched.weights ? ( - - {errors.weights} - - ) : ( - - {t('modelManager.modelLocationValidationMsg')} - - )} - - - - {/* VAE */} - - - {t('modelManager.vaeLocation')} - - - - {!!errors.vae && touched.vae ? ( - {errors.vae} - ) : ( - - {t('modelManager.vaeLocationValidationMsg')} - - )} - - - - - {/* Width */} - - - {t('modelManager.width')} - - - - {({ - field, - form, - }: { - field: FieldInputProps; - form: FormikProps; - }) => ( - - form.setFieldValue(field.name, Number(value)) - } - /> - )} - - - {!!errors.width && touched.width ? ( - - {errors.width} - - ) : ( - - {t('modelManager.widthValidationMsg')} - - )} - - - - {/* Height */} - - - {t('modelManager.height')} - - - - {({ - field, - form, - }: { - field: FieldInputProps; - form: FormikProps; - }) => ( - - form.setFieldValue(field.name, Number(value)) - } - /> - )} - - - {!!errors.height && touched.height ? ( - - {errors.height} - - ) : ( - - {t('modelManager.heightValidationMsg')} - - )} - - - - - - {t('modelManager.updateModel')} - - - - )} - - - - ) : ( - - Pick A Model To Edit - - ); -} diff --git a/invokeai/frontend/web/src/features/system/components/ModelManager/DiffusersModelEdit.tsx b/invokeai/frontend/web/src/features/system/components/ModelManager/DiffusersModelEdit.tsx deleted file mode 100644 index 81998e4976..0000000000 --- a/invokeai/frontend/web/src/features/system/components/ModelManager/DiffusersModelEdit.tsx +++ /dev/null @@ -1,281 +0,0 @@ -import { createSelector } from '@reduxjs/toolkit'; - -import IAIButton from 'common/components/IAIButton'; -import IAIInput from 'common/components/IAIInput'; -import { useEffect, useState } from 'react'; - -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { systemSelector } from 'features/system/store/systemSelectors'; - -import { Flex, FormControl, FormLabel, Text, VStack } from '@chakra-ui/react'; - -// import { addNewModel } from 'app/socketio/actions'; -import { Field, Formik } from 'formik'; -import { useTranslation } from 'react-i18next'; - -import type { InvokeDiffusersModelConfigProps } from 'app/types/invokeai'; -import type { RootState } from 'app/store/store'; -import { isEqual, pickBy } from 'lodash-es'; -import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText'; -import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage'; -import IAIForm from 'common/components/IAIForm'; - -const selector = createSelector( - [systemSelector], - (system) => { - const { openModel, model_list } = system; - return { - model_list, - openModel, - }; - }, - { - memoizeOptions: { - resultEqualityCheck: isEqual, - }, - } -); - -export default function DiffusersModelEdit() { - const { openModel, model_list } = useAppSelector(selector); - const isProcessing = useAppSelector( - (state: RootState) => state.system.isProcessing - ); - - const dispatch = useAppDispatch(); - - const { t } = useTranslation(); - - const [editModelFormValues, setEditModelFormValues] = - useState({ - name: '', - description: '', - repo_id: '', - path: '', - vae: { repo_id: '', path: '' }, - default: false, - format: 'diffusers', - }); - - useEffect(() => { - if (openModel) { - const retrievedModel = pickBy(model_list, (_val, key) => { - return isEqual(key, openModel); - }); - - setEditModelFormValues({ - name: openModel, - description: retrievedModel[openModel]?.description, - path: - retrievedModel[openModel]?.path && - retrievedModel[openModel]?.path !== 'None' - ? retrievedModel[openModel]?.path - : '', - repo_id: - retrievedModel[openModel]?.repo_id && - retrievedModel[openModel]?.repo_id !== 'None' - ? retrievedModel[openModel]?.repo_id - : '', - vae: { - repo_id: retrievedModel[openModel]?.vae?.repo_id - ? retrievedModel[openModel]?.vae?.repo_id - : '', - path: retrievedModel[openModel]?.vae?.path - ? retrievedModel[openModel]?.vae?.path - : '', - }, - default: retrievedModel[openModel]?.default, - format: 'diffusers', - }); - } - }, [model_list, openModel]); - - const editModelFormSubmitHandler = ( - values: InvokeDiffusersModelConfigProps - ) => { - const diffusersModelToEdit = values; - - if (values.path === '') delete diffusersModelToEdit.path; - if (values.repo_id === '') delete diffusersModelToEdit.repo_id; - if (values.vae.path === '') delete diffusersModelToEdit.vae.path; - if (values.vae.repo_id === '') delete diffusersModelToEdit.vae.repo_id; - - dispatch(addNewModel(values)); - }; - - return openModel ? ( - - - - {openModel} - - - - - {({ handleSubmit, errors, touched }) => ( - - - {/* Description */} - - - {t('modelManager.description')} - - - - {!!errors.description && touched.description ? ( - - {errors.description} - - ) : ( - - {t('modelManager.descriptionValidationMsg')} - - )} - - - - {/* Path */} - - - {t('modelManager.modelLocation')} - - - - {!!errors.path && touched.path ? ( - {errors.path} - ) : ( - - {t('modelManager.modelLocationValidationMsg')} - - )} - - - - {/* Repo ID */} - - - {t('modelManager.repo_id')} - - - - {!!errors.repo_id && touched.repo_id ? ( - - {errors.repo_id} - - ) : ( - - {t('modelManager.repoIDValidationMsg')} - - )} - - - - {/* VAE Path */} - - - {t('modelManager.vaeLocation')} - - - - {!!errors.vae?.path && touched.vae?.path ? ( - - {errors.vae?.path} - - ) : ( - - {t('modelManager.vaeLocationValidationMsg')} - - )} - - - - {/* VAE Repo ID */} - - - {t('modelManager.vaeRepoID')} - - - - {!!errors.vae?.repo_id && touched.vae?.repo_id ? ( - - {errors.vae?.repo_id} - - ) : ( - - {t('modelManager.vaeRepoIDValidationMsg')} - - )} - - - - - {t('modelManager.updateModel')} - - - - )} - - - - ) : ( - - Pick A Model To Edit - - ); -} diff --git a/invokeai/frontend/web/src/features/system/components/ModelManager/MergeModels.tsx b/invokeai/frontend/web/src/features/system/components/ModelManager/MergeModels.tsx deleted file mode 100644 index 219d49d4ee..0000000000 --- a/invokeai/frontend/web/src/features/system/components/ModelManager/MergeModels.tsx +++ /dev/null @@ -1,313 +0,0 @@ -import { - Flex, - Modal, - ModalBody, - ModalCloseButton, - ModalContent, - ModalFooter, - ModalHeader, - ModalOverlay, - Radio, - RadioGroup, - Text, - Tooltip, - useDisclosure, -} from '@chakra-ui/react'; -// import { mergeDiffusersModels } from 'app/socketio/actions'; -import { RootState } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import IAIButton from 'common/components/IAIButton'; -import IAIInput from 'common/components/IAIInput'; -import IAISelect from 'common/components/IAISelect'; -import { diffusersModelsSelector } from 'features/system/store/systemSelectors'; -import { useState } from 'react'; -import { useTranslation } from 'react-i18next'; -import * as InvokeAI from 'app/types/invokeai'; -import IAISlider from 'common/components/IAISlider'; -import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox'; - -export default function MergeModels() { - const dispatch = useAppDispatch(); - - const { isOpen, onOpen, onClose } = useDisclosure(); - - const diffusersModels = useAppSelector(diffusersModelsSelector); - - const { t } = useTranslation(); - - const [modelOne, setModelOne] = useState( - Object.keys(diffusersModels)[0] - ); - const [modelTwo, setModelTwo] = useState( - Object.keys(diffusersModels)[1] - ); - const [modelThree, setModelThree] = useState('none'); - - const [mergedModelName, setMergedModelName] = useState(''); - const [modelMergeAlpha, setModelMergeAlpha] = useState(0.5); - - const [modelMergeInterp, setModelMergeInterp] = useState< - 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference' - >('weighted_sum'); - - const [modelMergeSaveLocType, setModelMergeSaveLocType] = useState< - 'root' | 'custom' - >('root'); - - const [modelMergeCustomSaveLoc, setModelMergeCustomSaveLoc] = - useState(''); - - const [modelMergeForce, setModelMergeForce] = useState(false); - - const modelOneList = Object.keys(diffusersModels).filter( - (model) => model !== modelTwo && model !== modelThree - ); - - const modelTwoList = Object.keys(diffusersModels).filter( - (model) => model !== modelOne && model !== modelThree - ); - - const modelThreeList = [ - { key: t('modelManager.none'), value: 'none' }, - ...Object.keys(diffusersModels) - .filter((model) => model !== modelOne && model !== modelTwo) - .map((model) => ({ key: model, value: model })), - ]; - - const isProcessing = useAppSelector( - (state: RootState) => state.system.isProcessing - ); - - const mergeModelsHandler = () => { - let modelsToMerge: string[] = [modelOne, modelTwo, modelThree]; - modelsToMerge = modelsToMerge.filter((model) => model !== 'none'); - - const mergeModelsInfo: InvokeAI.InvokeModelMergingProps = { - models_to_merge: modelsToMerge, - merged_model_name: - mergedModelName !== '' ? mergedModelName : modelsToMerge.join('-'), - alpha: modelMergeAlpha, - interp: modelMergeInterp, - model_merge_save_path: - modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc, - force: modelMergeForce, - }; - - dispatch(mergeDiffusersModels(mergeModelsInfo)); - }; - - return ( - <> - - - {t('modelManager.mergeModels')} - - - - - - - {t('modelManager.mergeModels')} - - - - - {t('modelManager.modelMergeHeaderHelp1')} - - {t('modelManager.modelMergeHeaderHelp2')} - - - - setModelOne(e.target.value)} - /> - setModelTwo(e.target.value)} - /> - { - if (e.target.value !== 'none') { - setModelThree(e.target.value); - setModelMergeInterp('add_difference'); - } else { - setModelThree('none'); - setModelMergeInterp('weighted_sum'); - } - }} - /> - - - setMergedModelName(e.target.value)} - /> - - - setModelMergeAlpha(v)} - withInput - withReset - handleReset={() => setModelMergeAlpha(0.5)} - withSliderMarks - /> - - {t('modelManager.modelMergeAlphaHelp')} - - - - - - {t('modelManager.interpolationType')} - - setModelMergeInterp(v)} - > - - {modelThree === 'none' ? ( - <> - - - {t('modelManager.weightedSum')} - - - - {t('modelManager.sigmoid')} - - - - {t('modelManager.inverseSigmoid')} - - - - ) : ( - - - - {t('modelManager.addDifference')} - - - - )} - - - - - - - - {t('modelManager.mergedModelSaveLocation')} - - - setModelMergeSaveLocType(v) - } - > - - - - {t('modelManager.invokeAIFolder')} - - - - - {t('modelManager.custom')} - - - - - - {modelMergeSaveLocType === 'custom' && ( - setModelMergeCustomSaveLoc(e.target.value)} - /> - )} - - - setModelMergeForce(e.target.checked)} - fontWeight="500" - /> - - - {t('modelManager.merge')} - - - - - - - - ); -} diff --git a/invokeai/frontend/web/src/features/system/components/ModelManager/ModelManagerModal.tsx b/invokeai/frontend/web/src/features/system/components/ModelManager/ModelManagerModal.tsx deleted file mode 100644 index 440e5ad4db..0000000000 --- a/invokeai/frontend/web/src/features/system/components/ModelManager/ModelManagerModal.tsx +++ /dev/null @@ -1,76 +0,0 @@ -import { - Flex, - Modal, - ModalBody, - ModalCloseButton, - ModalContent, - ModalFooter, - ModalHeader, - ModalOverlay, - useDisclosure, -} from '@chakra-ui/react'; -import { cloneElement } from 'react'; - -import { RootState } from 'app/store/store'; -import { useAppSelector } from 'app/store/storeHooks'; -import { useTranslation } from 'react-i18next'; - -import type { ReactElement } from 'react'; - -import CheckpointModelEdit from './CheckpointModelEdit'; -import DiffusersModelEdit from './DiffusersModelEdit'; -import ModelList from './ModelList'; - -type ModelManagerModalProps = { - children: ReactElement; -}; - -export default function ModelManagerModal({ - children, -}: ModelManagerModalProps) { - const { - isOpen: isModelManagerModalOpen, - onOpen: onModelManagerModalOpen, - onClose: onModelManagerModalClose, - } = useDisclosure(); - - const model_list = useAppSelector( - (state: RootState) => state.system.model_list - ); - - const openModel = useAppSelector( - (state: RootState) => state.system.openModel - ); - - const { t } = useTranslation(); - - return ( - <> - {cloneElement(children, { - onClick: onModelManagerModalOpen, - })} - - - - - {t('modelManager.modelManager')} - - - - {openModel && model_list[openModel]['format'] === 'diffusers' ? ( - - ) : ( - - )} - - - - - - - ); -} diff --git a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx index 8b098936b3..4eeee3e4c6 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx +++ b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx @@ -5,10 +5,10 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { modelSelected } from 'features/parameters/store/generationSlice'; -import { forEach, isString } from 'lodash-es'; import { SelectItem } from '@mantine/core'; import { RootState } from 'app/store/store'; -import { useListModelsQuery } from 'services/api/endpoints/models'; +import { forEach, isString } from 'lodash-es'; +import { useGetMainModelsQuery } from 'services/api/endpoints/models'; export const MODEL_TYPE_MAP = { 'sd-1': 'Stable Diffusion 1.x', @@ -23,18 +23,16 @@ const ModelSelect = () => { (state: RootState) => state.generation.model ); - const { data: pipelineModels, isLoading } = useListModelsQuery({ - model_type: 'main', - }); + const { data: mainModels, isLoading } = useGetMainModelsQuery(); const data = useMemo(() => { - if (!pipelineModels) { + if (!mainModels) { return []; } const data: SelectItem[] = []; - forEach(pipelineModels.entities, (model, id) => { + forEach(mainModels.entities, (model, id) => { if (!model) { return; } @@ -47,11 +45,11 @@ const ModelSelect = () => { }); return data; - }, [pipelineModels]); + }, [mainModels]); const selectedModel = useMemo( - () => pipelineModels?.entities[selectedModelId], - [pipelineModels?.entities, selectedModelId] + () => mainModels?.entities[selectedModelId], + [mainModels?.entities, selectedModelId] ); const handleChangeModel = useCallback( @@ -65,20 +63,18 @@ const ModelSelect = () => { ); useEffect(() => { - // If the selected model is not in the list of models, select the first one - // Handles first-run setting of models, and the user deleting the previously-selected model - if (selectedModelId && pipelineModels?.ids.includes(selectedModelId)) { + if (selectedModelId && mainModels?.ids.includes(selectedModelId)) { return; } - const firstModel = pipelineModels?.ids[0]; + const firstModel = mainModels?.ids[0]; if (!isString(firstModel)) { return; } handleChangeModel(firstModel); - }, [handleChangeModel, pipelineModels?.ids, selectedModelId]); + }, [handleChangeModel, mainModels?.ids, selectedModelId]); return isLoading ? ( { const { t } = useTranslation(); - const isModelManagerEnabled = - useFeatureStatus('modelManager').isFeatureEnabled; const isLocalizationEnabled = useFeatureStatus('localization').isFeatureEnabled; const isBugLinkEnabled = useFeatureStatus('bugLink').isFeatureEnabled; @@ -37,20 +34,6 @@ const SiteHeader = () => { - {isModelManagerEnabled && ( - - } - /> - - )} - { const { t } = useTranslation(); - const isModelManagerEnabled = - useFeatureStatus('modelManager').isFeatureEnabled; const isLocalizationEnabled = useFeatureStatus('localization').isFeatureEnabled; const isBugLinkEnabled = useFeatureStatus('bugLink').isFeatureEnabled; @@ -27,20 +24,6 @@ const SiteHeaderMenu = () => { flexDirection={{ base: 'column', xl: 'row' }} gap={{ base: 4, xl: 1 }} > - {isModelManagerEnabled && ( - - } - /> - - )} - { + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + + const { data: vaeModels } = useGetVaeModelsQuery(); + + const selectedModelId = useAppSelector( + (state: RootState) => state.generation.vae + ); + + const data = useMemo(() => { + if (!vaeModels) { + return []; + } + + const data: SelectItem[] = [ + { + value: 'auto', + label: 'Automatic', + group: 'Default', + }, + ]; + + forEach(vaeModels.entities, (model, id) => { + if (!model) { + return; + } + + data.push({ + value: id, + label: model.name, + group: MODEL_TYPE_MAP[model.base_model], + }); + }); + + return data; + }, [vaeModels]); + + const selectedModel = useMemo( + () => vaeModels?.entities[selectedModelId], + [vaeModels?.entities, selectedModelId] + ); + + const handleChangeModel = useCallback( + (v: string | null) => { + if (!v) { + return; + } + dispatch(vaeSelected(v)); + }, + [dispatch] + ); + + useEffect(() => { + if (selectedModelId && vaeModels?.ids.includes(selectedModelId)) { + return; + } + handleChangeModel('auto'); + }, [handleChangeModel, vaeModels?.ids, selectedModelId]); + + return ( + + ); +}; + +export default memo(VAESelect); diff --git a/invokeai/frontend/web/src/features/ui/components/FloatingGalleryButton.tsx b/invokeai/frontend/web/src/features/ui/components/FloatingGalleryButton.tsx index 3e2c2153e6..af3eb72d8d 100644 --- a/invokeai/frontend/web/src/features/ui/components/FloatingGalleryButton.tsx +++ b/invokeai/frontend/web/src/features/ui/components/FloatingGalleryButton.tsx @@ -1,13 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIIconButton from 'common/components/IAIIconButton'; -import { useTranslation } from 'react-i18next'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { setShouldShowGallery } from 'features/ui/store/uiSlice'; import { isEqual } from 'lodash-es'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; import { MdPhotoLibrary } from 'react-icons/md'; import { activeTabNameSelector, uiSelector } from '../store/uiSelectors'; -import { memo } from 'react'; +import { NO_GALLERY_TABS } from './InvokeTabs'; const floatingGalleryButtonSelector = createSelector( [activeTabNameSelector, uiSelector], @@ -16,7 +17,9 @@ const floatingGalleryButtonSelector = createSelector( return { shouldPinGallery, - shouldShowGalleryButton: !shouldShowGallery, + shouldShowGalleryButton: NO_GALLERY_TABS.includes(activeTabName) + ? false + : !shouldShowGallery, }; }, { memoizeOptions: { resultEqualityCheck: isEqual } } diff --git a/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx b/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx index 1b2ae81072..c618997f03 100644 --- a/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx +++ b/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx @@ -9,35 +9,35 @@ import { Tooltip, VisuallyHidden, } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import AuxiliaryProgressIndicator from 'app/components/AuxiliaryProgressIndicator'; import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; +import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent'; import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice'; +import { configSelector } from 'features/system/store/configSelectors'; import { InvokeTabName } from 'features/ui/store/tabMap'; import { setActiveTab, togglePanels } from 'features/ui/store/uiSlice'; -import { memo, MouseEvent, ReactNode, useCallback, useMemo } from 'react'; +import { ResourceKey } from 'i18next'; +import { isEqual } from 'lodash-es'; +import { MouseEvent, ReactNode, memo, useCallback, useMemo } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; +import { useTranslation } from 'react-i18next'; +import { FaCube, FaFont, FaImage } from 'react-icons/fa'; import { MdDeviceHub, MdGridOn } from 'react-icons/md'; +import { Panel, PanelGroup } from 'react-resizable-panels'; +import { useMinimumPanelSize } from '../hooks/useMinimumPanelSize'; import { activeTabIndexSelector, activeTabNameSelector, } from '../store/uiSelectors'; -import { useTranslation } from 'react-i18next'; -import { ResourceKey } from 'i18next'; -import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; -import { createSelector } from '@reduxjs/toolkit'; -import { configSelector } from 'features/system/store/configSelectors'; -import { isEqual } from 'lodash-es'; -import { Panel, PanelGroup } from 'react-resizable-panels'; -import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent'; +import ImageTab from './tabs/ImageToImage/ImageToImageTab'; +import ModelManagerTab from './tabs/ModelManager/ModelManagerTab'; +import NodesTab from './tabs/Nodes/NodesTab'; +import ResizeHandle from './tabs/ResizeHandle'; import TextToImageTab from './tabs/TextToImage/TextToImageTab'; import UnifiedCanvasTab from './tabs/UnifiedCanvas/UnifiedCanvasTab'; -import NodesTab from './tabs/Nodes/NodesTab'; -import { FaFont, FaImage, FaLayerGroup } from 'react-icons/fa'; -import ResizeHandle from './tabs/ResizeHandle'; -import ImageTab from './tabs/ImageToImage/ImageToImageTab'; -import AuxiliaryProgressIndicator from 'app/components/AuxiliaryProgressIndicator'; -import { useMinimumPanelSize } from '../hooks/useMinimumPanelSize'; -import BatchTab from './tabs/Batch/BatchTab'; export interface InvokeTabInfo { id: InvokeTabName; @@ -66,6 +66,11 @@ const tabs: InvokeTabInfo[] = [ icon: , content: , }, + { + id: 'modelManager', + icon: , + content: , + }, // { // id: 'batch', // icon: , @@ -87,6 +92,7 @@ const enabledTabsSelector = createSelector( const MIN_GALLERY_WIDTH = 300; const DEFAULT_GALLERY_PCT = 20; +export const NO_GALLERY_TABS: InvokeTabName[] = ['modelManager']; const InvokeTabs = () => { const activeTab = useAppSelector(activeTabIndexSelector); @@ -198,26 +204,28 @@ const InvokeTabs = () => { {tabPanels} - {shouldPinGallery && shouldShowGallery && ( - <> - - DEFAULT_GALLERY_PCT - ? galleryMinSizePct - : DEFAULT_GALLERY_PCT - } - minSize={galleryMinSizePct} - maxSize={50} - > - - - - )} + {shouldPinGallery && + shouldShowGallery && + !NO_GALLERY_TABS.includes(activeTabName) && ( + <> + + DEFAULT_GALLERY_PCT + ? galleryMinSizePct + : DEFAULT_GALLERY_PCT + } + minSize={galleryMinSizePct} + maxSize={50} + > + + + + )} ); diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters.tsx index cdbec9b55d..5f5c7ad46b 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters.tsx @@ -1,38 +1,45 @@ -import { memo } from 'react'; -import { Box, Flex, useDisclosure } from '@chakra-ui/react'; +import { Box, Flex } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; -import { uiSelector } from 'features/ui/store/uiSelectors'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations'; -import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps'; -import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale'; -import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth'; -import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight'; -import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength'; -import ImageToImageFit from 'features/parameters/components/Parameters/ImageToImage/ImageToImageFit'; -import { generationSelector } from 'features/parameters/store/generationSelectors'; -import ParamSchedulerAndModel from 'features/parameters/components/Parameters/Core/ParamSchedulerAndModel'; -import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull'; import IAICollapse from 'common/components/IAICollapse'; +import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale'; +import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight'; +import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations'; +import ParamModelandVAE from 'features/parameters/components/Parameters/Core/ParamModelandVAE'; +import ParamScheduler from 'features/parameters/components/Parameters/Core/ParamScheduler'; +import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps'; +import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth'; +import ImageToImageFit from 'features/parameters/components/Parameters/ImageToImage/ImageToImageFit'; +import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength'; +import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull'; +import { generationSelector } from 'features/parameters/store/generationSelectors'; +import { uiSelector } from 'features/ui/store/uiSelectors'; +import { memo } from 'react'; const selector = createSelector( [uiSelector, generationSelector], (ui, generation) => { const { shouldUseSliders } = ui; - const { shouldFitToWidthHeight } = generation; + const { shouldFitToWidthHeight, shouldRandomizeSeed } = generation; - return { shouldUseSliders, shouldFitToWidthHeight }; + const activeLabel = !shouldRandomizeSeed ? 'Manual Seed' : undefined; + + return { shouldUseSliders, shouldFitToWidthHeight, activeLabel }; }, defaultSelectorOptions ); const ImageToImageTabCoreParameters = () => { - const { shouldUseSliders, shouldFitToWidthHeight } = useAppSelector(selector); - const { isOpen, onToggle } = useDisclosure({ defaultIsOpen: true }); + const { shouldUseSliders, shouldFitToWidthHeight, activeLabel } = + useAppSelector(selector); return ( - + { > {shouldUseSliders ? ( <> - + @@ -58,7 +65,8 @@ const ImageToImageTabCoreParameters = () => { - + + diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabParameters.tsx index 4f04abffa1..32b71d6187 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabParameters.tsx @@ -1,14 +1,15 @@ -import { memo } from 'react'; -import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; -import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning'; -import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning'; -import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; -import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse'; -import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse'; -import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse'; -import ImageToImageTabCoreParameters from './ImageToImageTabCoreParameters'; -import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse'; +import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse'; +import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; +import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning'; +import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning'; +import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse'; +import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse'; +import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse'; +import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; +import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; +import { memo } from 'react'; +import ImageToImageTabCoreParameters from './ImageToImageTabCoreParameters'; const ImageToImageTabParameters = () => { return ( @@ -17,6 +18,7 @@ const ImageToImageTabParameters = () => { + diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/ModelManagerTab.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/ModelManagerTab.tsx new file mode 100644 index 0000000000..8d675b17c8 --- /dev/null +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/ModelManagerTab.tsx @@ -0,0 +1,81 @@ +import { Tab, TabList, TabPanel, TabPanels, Tabs } from '@chakra-ui/react'; +import i18n from 'i18n'; +import { ReactNode, memo } from 'react'; +import AddModelsPanel from './subpanels/AddModelsPanel'; +import MergeModelsPanel from './subpanels/MergeModelsPanel'; +import ModelManagerPanel from './subpanels/ModelManagerPanel'; + +type ModelManagerTabName = 'modelManager' | 'addModels' | 'mergeModels'; + +type ModelManagerTabInfo = { + id: ModelManagerTabName; + label: string; + content: ReactNode; +}; + +const modelManagerTabs: ModelManagerTabInfo[] = [ + { + id: 'modelManager', + label: i18n.t('modelManager.modelManager'), + content: , + }, + { + id: 'addModels', + label: i18n.t('modelManager.addModel'), + content: , + }, + { + id: 'mergeModels', + label: i18n.t('modelManager.mergeModels'), + content: , + }, +]; + +const renderTabsList = () => { + const modelManagerTabListsToRender: ReactNode[] = []; + modelManagerTabs.forEach((modelManagerTab) => { + modelManagerTabListsToRender.push( + {modelManagerTab.label} + ); + }); + + return ( + + {modelManagerTabListsToRender} + + ); +}; + +const renderTabPanels = () => { + const modelManagerTabPanelsToRender: ReactNode[] = []; + modelManagerTabs.forEach((modelManagerTab) => { + modelManagerTabPanelsToRender.push( + {modelManagerTab.content} + ); + }); + + return {modelManagerTabPanelsToRender}; +}; + +const ModelManagerTab = () => { + return ( + + {renderTabsList()} + {renderTabPanels()} + + ); +}; + +export default memo(ModelManagerTab); diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel.tsx new file mode 100644 index 0000000000..25f4adf4aa --- /dev/null +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel.tsx @@ -0,0 +1,55 @@ +import { Divider, Flex } from '@chakra-ui/react'; +import { RootState } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import IAIButton from 'common/components/IAIButton'; +import { setAddNewModelUIOption } from 'features/ui/store/uiSlice'; +import { useTranslation } from 'react-i18next'; +import AddCheckpointModel from './AddModelsPanel/AddCheckpointModel'; +import AddDiffusersModel from './AddModelsPanel/AddDiffusersModel'; + +export default function AddModelsPanel() { + const addNewModelUIOption = useAppSelector( + (state: RootState) => state.ui.addNewModelUIOption + ); + + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + + return ( + + + dispatch(setAddNewModelUIOption('ckpt'))} + sx={{ + backgroundColor: + addNewModelUIOption == 'ckpt' ? 'accent.700' : 'base.700', + '&:hover': { + backgroundColor: + addNewModelUIOption == 'ckpt' ? 'accent.700' : 'base.600', + }, + }} + > + {t('modelManager.addCheckpointModel')} + + dispatch(setAddNewModelUIOption('diffusers'))} + sx={{ + backgroundColor: + addNewModelUIOption == 'diffusers' ? 'accent.700' : 'base.700', + '&:hover': { + backgroundColor: + addNewModelUIOption == 'diffusers' ? 'accent.700' : 'base.600', + }, + }} + > + {t('modelManager.addDiffuserModel')} + + + + + + {addNewModelUIOption == 'ckpt' && } + {addNewModelUIOption == 'diffusers' && } + + ); +} diff --git a/invokeai/frontend/web/src/features/system/components/ModelManager/AddCheckpointModel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AddCheckpointModel.tsx similarity index 99% rename from invokeai/frontend/web/src/features/system/components/ModelManager/AddCheckpointModel.tsx rename to invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AddCheckpointModel.tsx index e6bd0b6ffb..75e2017bb8 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelManager/AddCheckpointModel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AddCheckpointModel.tsx @@ -10,13 +10,11 @@ import { } from '@chakra-ui/react'; import IAIButton from 'common/components/IAIButton'; -import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox'; import IAIInput from 'common/components/IAIInput'; import IAINumberInput from 'common/components/IAINumberInput'; +import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox'; import React from 'react'; -import SearchModels from './SearchModels'; - // import { addNewModel } from 'app/socketio/actions'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; @@ -24,12 +22,13 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { Field, Formik } from 'formik'; import { useTranslation } from 'react-i18next'; -import type { InvokeModelConfigProps } from 'app/types/invokeai'; import type { RootState } from 'app/store/store'; -import { setAddNewModelUIOption } from 'features/ui/store/uiSlice'; -import type { FieldInputProps, FormikProps } from 'formik'; +import type { InvokeModelConfigProps } from 'app/types/invokeai'; import IAIForm from 'common/components/IAIForm'; import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper'; +import { setAddNewModelUIOption } from 'features/ui/store/uiSlice'; +import type { FieldInputProps, FormikProps } from 'formik'; +import SearchModels from './SearchModels'; const MIN_MODEL_SIZE = 64; const MAX_MODEL_SIZE = 2048; diff --git a/invokeai/frontend/web/src/features/system/components/ModelManager/AddDiffusersModel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AddDiffusersModel.tsx similarity index 99% rename from invokeai/frontend/web/src/features/system/components/ModelManager/AddDiffusersModel.tsx rename to invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AddDiffusersModel.tsx index cb3af5f176..dd491828da 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelManager/AddDiffusersModel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AddDiffusersModel.tsx @@ -66,7 +66,7 @@ export default function AddDiffusersModel() { }; return ( - + value?.model_format === 'diffusers' + ); + + const [modelOne, setModelOne] = useState( + Object.keys(diffusersModels)[0] + ); + const [modelTwo, setModelTwo] = useState( + Object.keys(diffusersModels)[1] + ); + const [modelThree, setModelThree] = useState('none'); + + const [mergedModelName, setMergedModelName] = useState(''); + const [modelMergeAlpha, setModelMergeAlpha] = useState(0.5); + + const [modelMergeInterp, setModelMergeInterp] = useState< + 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference' + >('weighted_sum'); + + const [modelMergeSaveLocType, setModelMergeSaveLocType] = useState< + 'root' | 'custom' + >('root'); + + const [modelMergeCustomSaveLoc, setModelMergeCustomSaveLoc] = + useState(''); + + const [modelMergeForce, setModelMergeForce] = useState(false); + + const modelOneList = Object.keys(diffusersModels).filter( + (model) => model !== modelTwo && model !== modelThree + ); + + const modelTwoList = Object.keys(diffusersModels).filter( + (model) => model !== modelOne && model !== modelThree + ); + + const modelThreeList = [ + { key: t('modelManager.none'), value: 'none' }, + ...Object.keys(diffusersModels) + .filter((model) => model !== modelOne && model !== modelTwo) + .map((model) => ({ key: model, value: model })), + ]; + + const isProcessing = useAppSelector( + (state: RootState) => state.system.isProcessing + ); + + const mergeModelsHandler = () => { + let modelsToMerge: string[] = [modelOne, modelTwo, modelThree]; + modelsToMerge = modelsToMerge.filter((model) => model !== 'none'); + + const mergeModelsInfo: InvokeAI.InvokeModelMergingProps = { + models_to_merge: modelsToMerge, + merged_model_name: + mergedModelName !== '' ? mergedModelName : modelsToMerge.join('-'), + alpha: modelMergeAlpha, + interp: modelMergeInterp, + model_merge_save_path: + modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc, + force: modelMergeForce, + }; + + dispatch(mergeDiffusersModels(mergeModelsInfo)); + }; + + return ( + + + {t('modelManager.modelMergeHeaderHelp1')} + + {t('modelManager.modelMergeHeaderHelp2')} + + + + setModelOne(e.target.value)} + /> + setModelTwo(e.target.value)} + /> + { + if (e.target.value !== 'none') { + setModelThree(e.target.value); + setModelMergeInterp('add_difference'); + } else { + setModelThree('none'); + setModelMergeInterp('weighted_sum'); + } + }} + /> + + + setMergedModelName(e.target.value)} + /> + + + setModelMergeAlpha(v)} + withInput + withReset + handleReset={() => setModelMergeAlpha(0.5)} + withSliderMarks + /> + + {t('modelManager.modelMergeAlphaHelp')} + + + + + + {t('modelManager.interpolationType')} + + setModelMergeInterp(v)} + > + + {modelThree === 'none' ? ( + <> + + {t('modelManager.weightedSum')} + + + {t('modelManager.sigmoid')} + + + {t('modelManager.inverseSigmoid')} + + + ) : ( + + + {t('modelManager.addDifference')} + + + )} + + + + + + + + {t('modelManager.mergedModelSaveLocation')} + + setModelMergeSaveLocType(v)} + > + + + {t('modelManager.invokeAIFolder')} + + + + {t('modelManager.custom')} + + + + + + {modelMergeSaveLocType === 'custom' && ( + setModelMergeCustomSaveLoc(e.target.value)} + /> + )} + + + setModelMergeForce(e.target.checked)} + fontWeight="500" + /> + + + {t('modelManager.merge')} + + + ); +} diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx new file mode 100644 index 0000000000..b22a303571 --- /dev/null +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx @@ -0,0 +1,44 @@ +import { Flex } from '@chakra-ui/react'; +import { RootState } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; + +import { useGetMainModelsQuery } from 'services/api/endpoints/models'; +import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit'; +import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit'; +import ModelList from './ModelManagerPanel/ModelList'; + +export default function ModelManagerPanel() { + const { data: mainModels } = useGetMainModelsQuery(); + + const openModel = useAppSelector( + (state: RootState) => state.system.openModel + ); + + const renderModelEditTabs = () => { + if (!openModel || !mainModels) return; + + if (mainModels['entities'][openModel]['model_format'] === 'diffusers') { + return ( + + ); + } else { + return ( + + ); + } + }; + return ( + + + {renderModelEditTabs()} + + ); +} diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx new file mode 100644 index 0000000000..0d5d21175a --- /dev/null +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx @@ -0,0 +1,141 @@ +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; + +import { Divider, Flex, Text } from '@chakra-ui/react'; + +// import { addNewModel } from 'app/socketio/actions'; +import { useForm } from '@mantine/form'; +import { useTranslation } from 'react-i18next'; + +import type { RootState } from 'app/store/store'; +import IAIButton from 'common/components/IAIButton'; +import IAIInput from 'common/components/IAIInput'; +import IAIMantineSelect from 'common/components/IAIMantineSelect'; +import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect'; +import { S } from 'services/api/types'; +import ModelConvert from './ModelConvert'; + +const baseModelSelectData = [ + { value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] }, + { value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] }, +]; + +const variantSelectData = [ + { value: 'normal', label: 'Normal' }, + { value: 'inpaint', label: 'Inpaint' }, + { value: 'depth', label: 'Depth' }, +]; + +export type CheckpointModel = + | S<'StableDiffusion1ModelCheckpointConfig'> + | S<'StableDiffusion2ModelCheckpointConfig'>; + +type CheckpointModelEditProps = { + modelToEdit: string; + retrievedModel: CheckpointModel; +}; + +export default function CheckpointModelEdit(props: CheckpointModelEditProps) { + const isProcessing = useAppSelector( + (state: RootState) => state.system.isProcessing + ); + + const { modelToEdit, retrievedModel } = props; + + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + + const checkpointEditForm = useForm({ + initialValues: { + name: retrievedModel.name, + base_model: retrievedModel.base_model, + type: 'main', + path: retrievedModel.path, + description: retrievedModel.description, + model_format: 'checkpoint', + vae: retrievedModel.vae, + config: retrievedModel.config, + variant: retrievedModel.variant, + }, + }); + + const editModelFormSubmitHandler = (values) => { + console.log(values); + }; + + return modelToEdit ? ( + + + + + {retrievedModel.name} + + + {MODEL_TYPE_MAP[retrievedModel.base_model]} Model + + + + + + + +
+ editModelFormSubmitHandler(values) + )} + > + + + + + + + + + + {t('modelManager.updateModel')} + + +
+
+
+ ) : ( + + Pick A Model To Edit + + ); +} diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx new file mode 100644 index 0000000000..6a7b4b3140 --- /dev/null +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx @@ -0,0 +1,125 @@ +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; + +import { Divider, Flex, Text } from '@chakra-ui/react'; + +// import { addNewModel } from 'app/socketio/actions'; +import { useTranslation } from 'react-i18next'; + +import { useForm } from '@mantine/form'; +import type { RootState } from 'app/store/store'; +import IAIButton from 'common/components/IAIButton'; +import IAIInput from 'common/components/IAIInput'; +import IAIMantineSelect from 'common/components/IAIMantineSelect'; +import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect'; +import { S } from 'services/api/types'; + +type DiffusersModel = + | S<'StableDiffusion1ModelDiffusersConfig'> + | S<'StableDiffusion2ModelDiffusersConfig'>; + +type DiffusersModelEditProps = { + modelToEdit: string; + retrievedModel: DiffusersModel; +}; + +const baseModelSelectData = [ + { value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] }, + { value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] }, +]; + +const variantSelectData = [ + { value: 'normal', label: 'Normal' }, + { value: 'inpaint', label: 'Inpaint' }, + { value: 'depth', label: 'Depth' }, +]; + +export default function DiffusersModelEdit(props: DiffusersModelEditProps) { + const isProcessing = useAppSelector( + (state: RootState) => state.system.isProcessing + ); + const { retrievedModel, modelToEdit } = props; + + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + + const diffusersEditForm = useForm({ + initialValues: { + name: retrievedModel.name, + base_model: retrievedModel.base_model, + type: 'main', + path: retrievedModel.path, + description: retrievedModel.description, + model_format: 'diffusers', + vae: retrievedModel.vae, + variant: retrievedModel.variant, + }, + }); + + const editModelFormSubmitHandler = (values) => { + console.log(values); + }; + + return modelToEdit ? ( + + + + {retrievedModel.name} + + + {MODEL_TYPE_MAP[retrievedModel.base_model]} Model + + + + +
+ editModelFormSubmitHandler(values) + )} + > + + + + + + + + + {t('modelManager.updateModel')} + + +
+
+ ) : ( + + Pick A Model To Edit + + ); +} diff --git a/invokeai/frontend/web/src/features/system/components/ModelManager/ModelConvert.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx similarity index 84% rename from invokeai/frontend/web/src/features/system/components/ModelManager/ModelConvert.tsx rename to invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx index 820ad546b3..9f571c2fff 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelManager/ModelConvert.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx @@ -4,42 +4,28 @@ import { Radio, RadioGroup, Text, - UnorderedList, Tooltip, + UnorderedList, } from '@chakra-ui/react'; // import { convertToDiffusers } from 'app/socketio/actions'; -import { RootState } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useAppDispatch } from 'app/store/storeHooks'; import IAIAlertDialog from 'common/components/IAIAlertDialog'; import IAIButton from 'common/components/IAIButton'; import IAIInput from 'common/components/IAIInput'; -import { useState, useEffect } from 'react'; +import { useEffect, useState } from 'react'; import { useTranslation } from 'react-i18next'; +import { CheckpointModel } from './CheckpointModelEdit'; interface ModelConvertProps { - model: string; + model: CheckpointModel; } export default function ModelConvert(props: ModelConvertProps) { const { model } = props; - const model_list = useAppSelector( - (state: RootState) => state.system.model_list - ); - - const retrievedModel = model_list[model]; - const dispatch = useAppDispatch(); const { t } = useTranslation(); - const isProcessing = useAppSelector( - (state: RootState) => state.system.isProcessing - ); - - const isConnected = useAppSelector( - (state: RootState) => state.system.isConnected - ); - const [saveLocation, setSaveLocation] = useState('same'); const [customSaveLocation, setCustomSaveLocation] = useState(''); @@ -65,7 +51,7 @@ export default function ModelConvert(props: ModelConvertProps) { return ( 🧨 {t('modelManager.convertToDiffusers')} diff --git a/invokeai/frontend/web/src/features/system/components/ModelManager/ModelList.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx similarity index 75% rename from invokeai/frontend/web/src/features/system/components/ModelManager/ModelList.tsx rename to invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx index 4ef311e1d4..eb05e70357 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelManager/ModelList.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx @@ -1,36 +1,14 @@ -import { Box, Flex, Heading, Spacer, Spinner, Text } from '@chakra-ui/react'; -import IAIInput from 'common/components/IAIInput'; +import { Box, Flex, Spinner, Text } from '@chakra-ui/react'; import IAIButton from 'common/components/IAIButton'; +import IAIInput from 'common/components/IAIInput'; -import AddModel from './AddModel'; import ModelListItem from './ModelListItem'; -import MergeModels from './MergeModels'; -import { useAppSelector } from 'app/store/storeHooks'; import { useTranslation } from 'react-i18next'; -import { createSelector } from '@reduxjs/toolkit'; -import { systemSelector } from 'features/system/store/systemSelectors'; -import type { SystemState } from 'features/system/store/systemSlice'; -import { isEqual, map } from 'lodash-es'; - -import React, { useMemo, useState, useTransition } from 'react'; import type { ChangeEvent, ReactNode } from 'react'; - -const modelListSelector = createSelector( - systemSelector, - (system: SystemState) => { - const models = map(system.model_list, (model, key) => { - return { name: key, ...model }; - }); - return models; - }, - { - memoizeOptions: { - resultEqualityCheck: isEqual, - }, - } -); +import React, { useMemo, useState, useTransition } from 'react'; +import { useGetMainModelsQuery } from 'services/api/endpoints/models'; function ModelFilterButton({ label, @@ -58,7 +36,7 @@ function ModelFilterButton({ } const ModelList = () => { - const models = useAppSelector(modelListSelector); + const { data: mainModels } = useGetMainModelsQuery(); const [renderModelList, setRenderModelList] = React.useState(false); @@ -90,43 +68,49 @@ const ModelList = () => { const filteredModelListItemsToRender: ReactNode[] = []; const localFilteredModelListItemsToRender: ReactNode[] = []; - models.forEach((model, i) => { - if (model.name.toLowerCase().includes(searchText.toLowerCase())) { + if (!mainModels) return; + + const modelList = mainModels.entities; + + Object.keys(modelList).forEach((model, i) => { + if ( + modelList[model].name.toLowerCase().includes(searchText.toLowerCase()) + ) { filteredModelListItemsToRender.push( ); - if (model.format === isSelectedFilter) { + if (modelList[model]?.model_format === isSelectedFilter) { localFilteredModelListItemsToRender.push( ); } } - if (model.format !== 'diffusers') { + if (modelList[model]?.model_format !== 'diffusers') { ckptModelListItemsToRender.push( ); } else { diffusersModelListItemsToRender.push( ); } @@ -142,6 +126,23 @@ const ModelList = () => { {isSelectedFilter === 'all' && ( <> + + + {t('modelManager.diffusersModels')} + + {diffusersModelListItemsToRender} + { {ckptModelListItemsToRender} - - - {t('modelManager.diffusersModels')} - - {diffusersModelListItemsToRender} - )} - {isSelectedFilter === 'ckpt' && ( - - {ckptModelListItemsToRender} - - )} - {isSelectedFilter === 'diffusers' && ( {diffusersModelListItemsToRender} )} + + {isSelectedFilter === 'ckpt' && ( + + {ckptModelListItemsToRender} + + )} ); - }, [models, searchText, t, isSelectedFilter]); + }, [mainModels, searchText, t, isSelectedFilter]); return ( - - {t('modelManager.availableModels')} - - - - - { { onClick={() => setIsSelectedFilter('all')} isActive={isSelectedFilter === 'all'} /> - setIsSelectedFilter('ckpt')} - isActive={isSelectedFilter === 'ckpt'} - /> setIsSelectedFilter('diffusers')} isActive={isSelectedFilter === 'diffusers'} /> + setIsSelectedFilter('ckpt')} + isActive={isSelectedFilter === 'ckpt'} + /> {renderModelList ? ( diff --git a/invokeai/frontend/web/src/features/system/components/ModelManager/ModelListItem.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx similarity index 75% rename from invokeai/frontend/web/src/features/system/components/ModelManager/ModelListItem.tsx rename to invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx index aa9f87816c..ab5fddd5ea 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelManager/ModelListItem.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx @@ -1,6 +1,6 @@ import { DeleteIcon, EditIcon } from '@chakra-ui/icons'; -import { Box, Button, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react'; -import { ModelStatus } from 'app/types/invokeai'; +import { Box, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react'; + // import { deleteModel, requestModelChange } from 'app/socketio/actions'; import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; @@ -10,9 +10,9 @@ import { setOpenModel } from 'features/system/store/systemSlice'; import { useTranslation } from 'react-i18next'; type ModelListItemProps = { + modelKey: string; name: string; - status: ModelStatus; - description: string; + description: string | undefined; }; export default function ModelListItem(props: ModelListItemProps) { @@ -28,39 +28,24 @@ export default function ModelListItem(props: ModelListItemProps) { const dispatch = useAppDispatch(); - const { name, status, description } = props; - - const handleChangeModel = () => { - dispatch(requestModelChange(name)); - }; + const { modelKey, name, description } = props; const openModelHandler = () => { - dispatch(setOpenModel(name)); + dispatch(setOpenModel(modelKey)); }; const handleModelDelete = () => { - dispatch(deleteModel(name)); + dispatch(deleteModel(modelKey)); dispatch(setOpenModel(null)); }; - const statusTextColor = () => { - switch (status) { - case 'active': - return 'ok.500'; - case 'cached': - return 'warning.500'; - case 'not loaded': - return 'inherit'; - } - }; - return ( - {status} - - } size="sm" diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters.tsx index 07297bda31..9211e095ba 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters.tsx @@ -1,34 +1,41 @@ -import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations'; -import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps'; -import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale'; -import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth'; -import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight'; -import { Box, Flex, useDisclosure } from '@chakra-ui/react'; -import { useAppSelector } from 'app/store/storeHooks'; +import { Box, Flex } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; -import { uiSelector } from 'features/ui/store/uiSelectors'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { memo } from 'react'; -import ParamSchedulerAndModel from 'features/parameters/components/Parameters/Core/ParamSchedulerAndModel'; import IAICollapse from 'common/components/IAICollapse'; +import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale'; +import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight'; +import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations'; +import ParamModelandVAE from 'features/parameters/components/Parameters/Core/ParamModelandVAE'; +import ParamScheduler from 'features/parameters/components/Parameters/Core/ParamScheduler'; +import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps'; +import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth'; import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull'; +import { memo } from 'react'; const selector = createSelector( - uiSelector, - (ui) => { + stateSelector, + ({ ui, generation }) => { const { shouldUseSliders } = ui; + const { shouldRandomizeSeed } = generation; - return { shouldUseSliders }; + const activeLabel = !shouldRandomizeSeed ? 'Manual Seed' : undefined; + + return { shouldUseSliders, activeLabel }; }, defaultSelectorOptions ); const TextToImageTabCoreParameters = () => { - const { shouldUseSliders } = useAppSelector(selector); - const { isOpen, onToggle } = useDisclosure({ defaultIsOpen: true }); + const { shouldUseSliders, activeLabel } = useAppSelector(selector); return ( - + { > {shouldUseSliders ? ( <> - + @@ -54,7 +61,8 @@ const TextToImageTabCoreParameters = () => { - + + diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabParameters.tsx index bcc6c91ae6..6291b69a8e 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabParameters.tsx @@ -1,15 +1,16 @@ +import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse'; +import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse'; +import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; +import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning'; +import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning'; +import ParamHiresCollapse from 'features/parameters/components/Parameters/Hires/ParamHiresCollapse'; +import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse'; +import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse'; +import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse'; +import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; import { memo } from 'react'; -import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning'; -import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning'; -import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; -import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse'; -import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse'; -import ParamHiresCollapse from 'features/parameters/components/Parameters/Hires/ParamHiresCollapse'; -import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse'; import TextToImageTabCoreParameters from './TextToImageTabCoreParameters'; -import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; -import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse'; const TextToImageTabParameters = () => { return ( @@ -18,6 +19,7 @@ const TextToImageTabParameters = () => { + diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx index 42e19eb096..330cd8b31e 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx @@ -1,35 +1,42 @@ -import { memo } from 'react'; -import { Box, Flex, useDisclosure } from '@chakra-ui/react'; +import { Box, Flex } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; -import { uiSelector } from 'features/ui/store/uiSelectors'; +import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations'; -import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps'; -import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale'; -import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength'; -import ParamSchedulerAndModel from 'features/parameters/components/Parameters/Core/ParamSchedulerAndModel'; -import ParamBoundingBoxWidth from 'features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxWidth'; -import ParamBoundingBoxHeight from 'features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxHeight'; -import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull'; import IAICollapse from 'common/components/IAICollapse'; +import ParamBoundingBoxHeight from 'features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxHeight'; +import ParamBoundingBoxWidth from 'features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxWidth'; +import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale'; +import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations'; +import ParamModelandVAE from 'features/parameters/components/Parameters/Core/ParamModelandVAE'; +import ParamScheduler from 'features/parameters/components/Parameters/Core/ParamScheduler'; +import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps'; +import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength'; +import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull'; +import { memo } from 'react'; const selector = createSelector( - uiSelector, - (ui) => { + stateSelector, + ({ ui, generation }) => { const { shouldUseSliders } = ui; + const { shouldRandomizeSeed } = generation; - return { shouldUseSliders }; + const activeLabel = !shouldRandomizeSeed ? 'Manual Seed' : undefined; + + return { shouldUseSliders, activeLabel }; }, defaultSelectorOptions ); const UnifiedCanvasCoreParameters = () => { - const { shouldUseSliders } = useAppSelector(selector); - const { isOpen, onToggle } = useDisclosure({ defaultIsOpen: true }); + const { shouldUseSliders, activeLabel } = useAppSelector(selector); return ( - + { > {shouldUseSliders ? ( <> - + @@ -55,7 +62,8 @@ const UnifiedCanvasCoreParameters = () => { - + + diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx index 061ebb962e..63ed4cc1cf 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx @@ -1,14 +1,15 @@ -import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; -import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; -import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse'; +import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse'; +import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse'; import ParamInfillAndScalingCollapse from 'features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse'; import ParamSeamCorrectionCollapse from 'features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse'; -import UnifiedCanvasCoreParameters from './UnifiedCanvasCoreParameters'; -import { memo } from 'react'; -import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning'; -import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning'; import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; -import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse'; +import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning'; +import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning'; +import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse'; +import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; +import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; +import { memo } from 'react'; +import UnifiedCanvasCoreParameters from './UnifiedCanvasCoreParameters'; const UnifiedCanvasParameters = () => { return ( @@ -17,6 +18,7 @@ const UnifiedCanvasParameters = () => { + diff --git a/invokeai/frontend/web/src/features/ui/store/tabMap.ts b/invokeai/frontend/web/src/features/ui/store/tabMap.ts index 4f683c95cb..0cae8eac43 100644 --- a/invokeai/frontend/web/src/features/ui/store/tabMap.ts +++ b/invokeai/frontend/web/src/features/ui/store/tabMap.ts @@ -1,12 +1,10 @@ export const tabMap = [ 'txt2img', 'img2img', - // 'generate', 'unifiedCanvas', 'nodes', + 'modelManager', 'batch', - // 'postprocessing', - // 'training', ] as const; export type InvokeTabName = (typeof tabMap)[number]; diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 39e4e46d3b..a9a914f0f2 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -1,37 +1,85 @@ -import { ModelsList } from 'services/api/types'; import { EntityState, createEntityAdapter } from '@reduxjs/toolkit'; -import { keyBy } from 'lodash-es'; +import { cloneDeep } from 'lodash-es'; +import { + AnyModelConfig, + ControlNetModelConfig, + LoRAModelConfig, + MainModelConfig, + TextualInversionModelConfig, + VaeModelConfig, +} from 'services/api/types'; import { ApiFullTagDescription, LIST_TAG, api } from '..'; -import { paths } from '../schema'; -type ModelConfig = ModelsList['models'][number]; +export type MainModelConfigEntity = MainModelConfig & { id: string }; -type ListModelsArg = NonNullable< - paths['/api/v1/models/']['get']['parameters']['query'] ->; +export type LoRAModelConfigEntity = LoRAModelConfig & { id: string }; -const modelsAdapter = createEntityAdapter({ - selectId: (model) => getModelId(model), +export type ControlNetModelConfigEntity = ControlNetModelConfig & { + id: string; +}; + +export type TextualInversionModelConfigEntity = TextualInversionModelConfig & { + id: string; +}; + +export type VaeModelConfigEntity = VaeModelConfig & { id: string }; + +type AnyModelConfigEntity = + | MainModelConfigEntity + | LoRAModelConfigEntity + | ControlNetModelConfigEntity + | TextualInversionModelConfigEntity + | VaeModelConfigEntity; + +const mainModelsAdapter = createEntityAdapter({ + sortComparer: (a, b) => a.name.localeCompare(b.name), +}); +const loraModelsAdapter = createEntityAdapter({ + sortComparer: (a, b) => a.name.localeCompare(b.name), +}); +const controlNetModelsAdapter = + createEntityAdapter({ + sortComparer: (a, b) => a.name.localeCompare(b.name), + }); +const textualInversionModelsAdapter = + createEntityAdapter({ + sortComparer: (a, b) => a.name.localeCompare(b.name), + }); +const vaeModelsAdapter = createEntityAdapter({ sortComparer: (a, b) => a.name.localeCompare(b.name), }); -const getModelId = ({ base_model, type, name }: ModelConfig) => +export const getModelId = ({ base_model, type, name }: AnyModelConfig) => `${base_model}/${type}/${name}`; +const createModelEntities = ( + models: AnyModelConfig[] +): T[] => { + const entityArray: T[] = []; + models.forEach((model) => { + const entity = { + ...cloneDeep(model), + id: getModelId(model), + } as T; + entityArray.push(entity); + }); + return entityArray; +}; + export const modelsApi = api.injectEndpoints({ endpoints: (build) => ({ - listModels: build.query, ListModelsArg>({ - query: (arg) => ({ url: 'models/', params: arg }), + getMainModels: build.query, void>({ + query: () => ({ url: 'models/', params: { model_type: 'main' } }), providesTags: (result, error, arg) => { - // any list of boards - const tags: ApiFullTagDescription[] = [{ id: 'Model', type: LIST_TAG }]; + const tags: ApiFullTagDescription[] = [ + { id: 'MainModel', type: LIST_TAG }, + ]; if (result) { - // and individual tags for each board tags.push( ...result.ids.map((id) => ({ - type: 'Model' as const, + type: 'MainModel' as const, id, })) ); @@ -39,14 +87,161 @@ export const modelsApi = api.injectEndpoints({ return tags; }, - transformResponse: (response: ModelsList, meta, arg) => { - return modelsAdapter.setAll( - modelsAdapter.getInitialState(), - keyBy(response.models, getModelId) + transformResponse: ( + response: { models: MainModelConfig[] }, + meta, + arg + ) => { + const entities = createModelEntities( + response.models + ); + return mainModelsAdapter.setAll( + mainModelsAdapter.getInitialState(), + entities + ); + }, + }), + getLoRAModels: build.query, void>({ + query: () => ({ url: 'models/', params: { model_type: 'lora' } }), + providesTags: (result, error, arg) => { + const tags: ApiFullTagDescription[] = [ + { id: 'LoRAModel', type: LIST_TAG }, + ]; + + if (result) { + tags.push( + ...result.ids.map((id) => ({ + type: 'LoRAModel' as const, + id, + })) + ); + } + + return tags; + }, + transformResponse: ( + response: { models: LoRAModelConfig[] }, + meta, + arg + ) => { + const entities = createModelEntities( + response.models + ); + return loraModelsAdapter.setAll( + loraModelsAdapter.getInitialState(), + entities + ); + }, + }), + getControlNetModels: build.query< + EntityState, + void + >({ + query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }), + providesTags: (result, error, arg) => { + const tags: ApiFullTagDescription[] = [ + { id: 'ControlNetModel', type: LIST_TAG }, + ]; + + if (result) { + tags.push( + ...result.ids.map((id) => ({ + type: 'ControlNetModel' as const, + id, + })) + ); + } + + return tags; + }, + transformResponse: ( + response: { models: ControlNetModelConfig[] }, + meta, + arg + ) => { + const entities = createModelEntities( + response.models + ); + return controlNetModelsAdapter.setAll( + controlNetModelsAdapter.getInitialState(), + entities + ); + }, + }), + getVaeModels: build.query, void>({ + query: () => ({ url: 'models/', params: { model_type: 'vae' } }), + providesTags: (result, error, arg) => { + const tags: ApiFullTagDescription[] = [ + { id: 'VaeModel', type: LIST_TAG }, + ]; + + if (result) { + tags.push( + ...result.ids.map((id) => ({ + type: 'VaeModel' as const, + id, + })) + ); + } + + return tags; + }, + transformResponse: ( + response: { models: VaeModelConfig[] }, + meta, + arg + ) => { + const entities = createModelEntities( + response.models + ); + return vaeModelsAdapter.setAll( + vaeModelsAdapter.getInitialState(), + entities + ); + }, + }), + getTextualInversionModels: build.query< + EntityState, + void + >({ + query: () => ({ url: 'models/', params: { model_type: 'embedding' } }), + providesTags: (result, error, arg) => { + const tags: ApiFullTagDescription[] = [ + { id: 'TextualInversionModel', type: LIST_TAG }, + ]; + + if (result) { + tags.push( + ...result.ids.map((id) => ({ + type: 'TextualInversionModel' as const, + id, + })) + ); + } + + return tags; + }, + transformResponse: ( + response: { models: TextualInversionModelConfig[] }, + meta, + arg + ) => { + const entities = createModelEntities( + response.models + ); + return textualInversionModelsAdapter.setAll( + textualInversionModelsAdapter.getInitialState(), + entities ); }, }), }), }); -export const { useListModelsQuery } = modelsApi; +export const { + useGetMainModelsQuery, + useGetControlNetModelsQuery, + useGetLoRAModelsQuery, + useGetTextualInversionModelsQuery, + useGetVaeModelsQuery, +} = modelsApi; diff --git a/invokeai/frontend/web/src/services/api/schema.d.ts b/invokeai/frontend/web/src/services/api/schema.d.ts index e542cd4ba2..d7e50d004e 100644 --- a/invokeai/frontend/web/src/services/api/schema.d.ts +++ b/invokeai/frontend/web/src/services/api/schema.d.ts @@ -76,9 +76,16 @@ export type paths = { */ get: operations["list_models"]; /** - * Import Model + * Update Model * @description Add Model */ + post: operations["update_model"]; + }; + "/api/v1/models/import": { + /** + * Import Model + * @description Add a model using its local path, repo_id, or remote URL + */ post: operations["import_model"]; }; "/api/v1/models/{model_name}": { @@ -227,6 +234,23 @@ export type components = { */ b?: number; }; + /** AddModelResult */ + AddModelResult: { + /** + * Name + * @description The name of the model after import + */ + name: string; + /** @description The type of model */ + model_type: components["schemas"]["ModelType"]; + /** @description The base model */ + base_model: components["schemas"]["BaseModelType"]; + /** + * Config + * @description The configuration of the model + */ + config: components["schemas"]["ModelConfigBase"]; + }; /** * BaseModelType * @description An enumeration. @@ -1030,7 +1054,7 @@ export type components = { * @description The nodes in this graph */ nodes?: { - [key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined; + [key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined; }; /** * Edges @@ -1073,7 +1097,7 @@ export type components = { * @description The results of node executions */ results: { - [key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined; + [key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined; }; /** * Errors @@ -1975,19 +1999,23 @@ export type components = { */ thumbnail_url: string; }; - /** ImportModelRequest */ - ImportModelRequest: { + /** ImportModelResponse */ + ImportModelResponse: { /** * Name - * @description A model path, repo_id or URL to import + * @description The name of the imported model */ name: string; /** - * Prediction Type - * @description Prediction type for SDv2 checkpoint files - * @enum {string} + * Info + * @description The model info */ - prediction_type?: "epsilon" | "v_prediction" | "sample"; + info: components["schemas"]["AddModelResult"]; + /** + * Status + * @description The status of the API response + */ + status: string; }; /** * InfillColorInvocation @@ -2662,6 +2690,19 @@ export type components = { model_format: components["schemas"]["LoRAModelFormat"]; error?: components["schemas"]["ModelError"]; }; + /** + * LoRAModelField + * @description LoRA model field + */ + LoRAModelField: { + /** + * Model Name + * @description Name of the LoRA model + */ + model_name: string; + /** @description Base model */ + base_model: components["schemas"]["BaseModelType"]; + }; /** * LoRAModelFormat * @description An enumeration. @@ -2738,10 +2779,10 @@ export type components = { */ type?: "lora_loader"; /** - * Lora Name + * Lora * @description Lora model name */ - lora_name: string; + lora?: components["schemas"]["LoRAModelField"]; /** * Weight * @description With what weight to apply lora @@ -2781,6 +2822,47 @@ export type components = { */ clip?: components["schemas"]["ClipField"]; }; + /** + * MainModelField + * @description Main model field + */ + MainModelField: { + /** + * Model Name + * @description Name of the model + */ + model_name: string; + /** @description Base model */ + base_model: components["schemas"]["BaseModelType"]; + }; + /** + * MainModelLoaderInvocation + * @description Loads a main model, outputting its submodels. + */ + MainModelLoaderInvocation: { + /** + * Id + * @description The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this node is an intermediate node. + * @default false + */ + is_intermediate?: boolean; + /** + * Type + * @default main_model_loader + * @enum {string} + */ + type?: "main_model_loader"; + /** + * Model + * @description The model to load + */ + model: components["schemas"]["MainModelField"]; + }; /** * MaskFromAlphaInvocation * @description Extracts the alpha channel of an image as a mask. @@ -2974,6 +3056,16 @@ export type components = { */ thr_d?: number; }; + /** ModelConfigBase */ + ModelConfigBase: { + /** Path */ + path: string; + /** Description */ + description?: string; + /** Model Format */ + model_format?: string; + error?: components["schemas"]["ModelError"]; + }; /** * ModelError * @description An enumeration. @@ -3036,7 +3128,7 @@ export type components = { /** ModelsList */ ModelsList: { /** Models */ - models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[]; + models: (components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[]; }; /** * MultiplyInvocation @@ -3425,47 +3517,6 @@ export type components = { */ scribble?: boolean; }; - /** - * PipelineModelField - * @description Pipeline model field - */ - PipelineModelField: { - /** - * Model Name - * @description Name of the model - */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["BaseModelType"]; - }; - /** - * PipelineModelLoaderInvocation - * @description Loads a pipeline model, outputting its submodels. - */ - PipelineModelLoaderInvocation: { - /** - * Id - * @description The id of this node. Must be unique among all nodes. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this node is an intermediate node. - * @default false - */ - is_intermediate?: boolean; - /** - * Type - * @default pipeline_model_loader - * @enum {string} - */ - type?: "pipeline_model_loader"; - /** - * Model - * @description The model to load - */ - model: components["schemas"]["PipelineModelField"]; - }; /** * PromptCollectionOutput * @description Base class for invocations that output a collection of prompts @@ -4266,6 +4317,19 @@ export type components = { */ level?: 2 | 4; }; + /** + * VAEModelField + * @description Vae model field + */ + VAEModelField: { + /** + * Model Name + * @description Name of the model + */ + model_name: string; + /** @description Base model */ + base_model: components["schemas"]["BaseModelType"]; + }; /** VaeField */ VaeField: { /** @@ -4274,6 +4338,51 @@ export type components = { */ vae: components["schemas"]["ModelInfo"]; }; + /** + * VaeLoaderInvocation + * @description Loads a VAE model, outputting a VaeLoaderOutput + */ + VaeLoaderInvocation: { + /** + * Id + * @description The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this node is an intermediate node. + * @default false + */ + is_intermediate?: boolean; + /** + * Type + * @default vae_loader + * @enum {string} + */ + type?: "vae_loader"; + /** + * Vae Model + * @description The VAE to load + */ + vae_model: components["schemas"]["VAEModelField"]; + }; + /** + * VaeLoaderOutput + * @description Model loader output + */ + VaeLoaderOutput: { + /** + * Type + * @default vae_loader_output + * @enum {string} + */ + type?: "vae_loader_output"; + /** + * Vae + * @description Vae model + */ + vae?: components["schemas"]["VaeField"]; + }; /** VaeModelConfig */ VaeModelConfig: { /** Name */ @@ -4352,18 +4461,18 @@ export type components = { */ image?: components["schemas"]["ImageField"]; }; - /** - * StableDiffusion2ModelFormat - * @description An enumeration. - * @enum {string} - */ - StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; /** * StableDiffusion1ModelFormat * @description An enumeration. * @enum {string} */ StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; + /** + * StableDiffusion2ModelFormat + * @description An enumeration. + * @enum {string} + */ + StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; }; responses: never; parameters: never; @@ -4474,7 +4583,7 @@ export type operations = { }; requestBody: { content: { - "application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]; + "application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]; }; }; responses: { @@ -4511,7 +4620,7 @@ export type operations = { }; requestBody: { content: { - "application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]; + "application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]; }; }; responses: { @@ -4731,13 +4840,13 @@ export type operations = { }; }; /** - * Import Model + * Update Model * @description Add Model */ - import_model: { + update_model: { requestBody: { content: { - "application/json": components["schemas"]["ImportModelRequest"]; + "application/json": components["schemas"]["CreateModelRequest"]; }; }; responses: { @@ -4755,6 +4864,36 @@ export type operations = { }; }; }; + /** + * Import Model + * @description Add a model using its local path, repo_id, or remote URL + */ + import_model: { + parameters: { + query: { + /** @description A model path, repo_id or URL to import */ + name: string; + /** @description Prediction type for SDv2 checkpoint files */ + prediction_type?: "v_prediction" | "epsilon" | "sample"; + }; + }; + responses: { + /** @description The model imported successfully */ + 201: { + content: { + "application/json": components["schemas"]["ImportModelResponse"]; + }; + }; + /** @description The model could not be found */ + 404: never; + /** @description Validation Error */ + 422: { + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; /** * Delete Model * @description Delete Model diff --git a/invokeai/frontend/web/src/services/api/types.d.ts b/invokeai/frontend/web/src/services/api/types.d.ts index 12c072509b..3a0bdb71a7 100644 --- a/invokeai/frontend/web/src/services/api/types.d.ts +++ b/invokeai/frontend/web/src/services/api/types.d.ts @@ -4,90 +4,156 @@ import { components } from './schema'; type schemas = components['schemas']; /** - * Extracts the schema type from the schema. + * Marks the `type` property as required. Use for nodes. */ -type S = components['schemas'][T]; - -/** - * Extracts the node type from the schema. - * Also flags the `type` property as required. - */ -type N = O.Required< - components['schemas'][T], - 'type' ->; +type TypeReq = O.Required; // Images -export type ImageDTO = S<'ImageDTO'>; -export type BoardDTO = S<'BoardDTO'>; -export type BoardChanges = S<'BoardChanges'>; -export type ImageChanges = S<'ImageRecordChanges'>; -export type ImageCategory = S<'ImageCategory'>; -export type ResourceOrigin = S<'ResourceOrigin'>; -export type ImageField = S<'ImageField'>; +export type ImageDTO = components['schemas']['ImageDTO']; +export type BoardDTO = components['schemas']['BoardDTO']; +export type BoardChanges = components['schemas']['BoardChanges']; +export type ImageChanges = components['schemas']['ImageRecordChanges']; +export type ImageCategory = components['schemas']['ImageCategory']; +export type ResourceOrigin = components['schemas']['ResourceOrigin']; +export type ImageField = components['schemas']['ImageField']; export type OffsetPaginatedResults_BoardDTO_ = - S<'OffsetPaginatedResults_BoardDTO_'>; + components['schemas']['OffsetPaginatedResults_BoardDTO_']; export type OffsetPaginatedResults_ImageDTO_ = - S<'OffsetPaginatedResults_ImageDTO_'>; + components['schemas']['OffsetPaginatedResults_ImageDTO_']; // Models -export type ModelType = S<'ModelType'>; -export type BaseModelType = S<'BaseModelType'>; -export type PipelineModelField = S<'PipelineModelField'>; -export type ModelsList = S<'ModelsList'>; +export type ModelType = components['schemas']['ModelType']; +export type BaseModelType = components['schemas']['BaseModelType']; +export type MainModelField = components['schemas']['MainModelField']; +export type VAEModelField = components['schemas']['VAEModelField']; +export type LoRAModelField = components['schemas']['LoRAModelField']; +export type ModelsList = components['schemas']['ModelsList']; + +// Model Configs +export type LoRAModelConfig = components['schemas']['LoRAModelConfig']; +export type VaeModelConfig = components['schemas']['VaeModelConfig']; +export type ControlNetModelConfig = + components['schemas']['ControlNetModelConfig']; +export type TextualInversionModelConfig = + components['schemas']['TextualInversionModelConfig']; +export type MainModelConfig = + | components['schemas']['StableDiffusion1ModelCheckpointConfig'] + | components['schemas']['StableDiffusion1ModelDiffusersConfig'] + | components['schemas']['StableDiffusion2ModelCheckpointConfig'] + | components['schemas']['StableDiffusion2ModelDiffusersConfig']; +export type AnyModelConfig = + | LoRAModelConfig + | VaeModelConfig + | ControlNetModelConfig + | TextualInversionModelConfig + | MainModelConfig; // Graphs -export type Graph = S<'Graph'>; -export type Edge = S<'Edge'>; -export type GraphExecutionState = S<'GraphExecutionState'>; +export type Graph = components['schemas']['Graph']; +export type Edge = components['schemas']['Edge']; +export type GraphExecutionState = components['schemas']['GraphExecutionState']; // General nodes -export type CollectInvocation = N<'CollectInvocation'>; -export type IterateInvocation = N<'IterateInvocation'>; -export type RangeInvocation = N<'RangeInvocation'>; -export type RandomRangeInvocation = N<'RandomRangeInvocation'>; -export type RangeOfSizeInvocation = N<'RangeOfSizeInvocation'>; -export type InpaintInvocation = N<'InpaintInvocation'>; -export type ImageResizeInvocation = N<'ImageResizeInvocation'>; -export type RandomIntInvocation = N<'RandomIntInvocation'>; -export type CompelInvocation = N<'CompelInvocation'>; -export type DynamicPromptInvocation = N<'DynamicPromptInvocation'>; -export type NoiseInvocation = N<'NoiseInvocation'>; -export type TextToLatentsInvocation = N<'TextToLatentsInvocation'>; -export type LatentsToLatentsInvocation = N<'LatentsToLatentsInvocation'>; -export type ImageToLatentsInvocation = N<'ImageToLatentsInvocation'>; -export type LatentsToImageInvocation = N<'LatentsToImageInvocation'>; -export type PipelineModelLoaderInvocation = N<'PipelineModelLoaderInvocation'>; -export type ImageCollectionInvocation = N<'ImageCollectionInvocation'>; +export type CollectInvocation = TypeReq< + components['schemas']['CollectInvocation'] +>; +export type IterateInvocation = TypeReq< + components['schemas']['IterateInvocation'] +>; +export type RangeInvocation = TypeReq; +export type RandomRangeInvocation = TypeReq< + components['schemas']['RandomRangeInvocation'] +>; +export type RangeOfSizeInvocation = TypeReq< + components['schemas']['RangeOfSizeInvocation'] +>; +export type InpaintInvocation = TypeReq< + components['schemas']['InpaintInvocation'] +>; +export type ImageResizeInvocation = TypeReq< + components['schemas']['ImageResizeInvocation'] +>; +export type RandomIntInvocation = TypeReq< + components['schemas']['RandomIntInvocation'] +>; +export type CompelInvocation = TypeReq< + components['schemas']['CompelInvocation'] +>; +export type DynamicPromptInvocation = TypeReq< + components['schemas']['DynamicPromptInvocation'] +>; +export type NoiseInvocation = TypeReq; +export type TextToLatentsInvocation = TypeReq< + components['schemas']['TextToLatentsInvocation'] +>; +export type LatentsToLatentsInvocation = TypeReq< + components['schemas']['LatentsToLatentsInvocation'] +>; +export type ImageToLatentsInvocation = TypeReq< + components['schemas']['ImageToLatentsInvocation'] +>; +export type LatentsToImageInvocation = TypeReq< + components['schemas']['LatentsToImageInvocation'] +>; +export type ImageCollectionInvocation = TypeReq< + components['schemas']['ImageCollectionInvocation'] +>; +export type MainModelLoaderInvocation = TypeReq< + components['schemas']['MainModelLoaderInvocation'] +>; +export type LoraLoaderInvocation = TypeReq< + components['schemas']['LoraLoaderInvocation'] +>; // ControlNet Nodes -export type ControlNetInvocation = N<'ControlNetInvocation'>; -export type CannyImageProcessorInvocation = N<'CannyImageProcessorInvocation'>; -export type ContentShuffleImageProcessorInvocation = - N<'ContentShuffleImageProcessorInvocation'>; -export type HedImageProcessorInvocation = N<'HedImageProcessorInvocation'>; -export type LineartAnimeImageProcessorInvocation = - N<'LineartAnimeImageProcessorInvocation'>; -export type LineartImageProcessorInvocation = - N<'LineartImageProcessorInvocation'>; -export type MediapipeFaceProcessorInvocation = - N<'MediapipeFaceProcessorInvocation'>; -export type MidasDepthImageProcessorInvocation = - N<'MidasDepthImageProcessorInvocation'>; -export type MlsdImageProcessorInvocation = N<'MlsdImageProcessorInvocation'>; -export type NormalbaeImageProcessorInvocation = - N<'NormalbaeImageProcessorInvocation'>; -export type OpenposeImageProcessorInvocation = - N<'OpenposeImageProcessorInvocation'>; -export type PidiImageProcessorInvocation = N<'PidiImageProcessorInvocation'>; -export type ZoeDepthImageProcessorInvocation = - N<'ZoeDepthImageProcessorInvocation'>; +export type ControlNetInvocation = TypeReq< + components['schemas']['ControlNetInvocation'] +>; +export type CannyImageProcessorInvocation = TypeReq< + components['schemas']['CannyImageProcessorInvocation'] +>; +export type ContentShuffleImageProcessorInvocation = TypeReq< + components['schemas']['ContentShuffleImageProcessorInvocation'] +>; +export type HedImageProcessorInvocation = TypeReq< + components['schemas']['HedImageProcessorInvocation'] +>; +export type LineartAnimeImageProcessorInvocation = TypeReq< + components['schemas']['LineartAnimeImageProcessorInvocation'] +>; +export type LineartImageProcessorInvocation = TypeReq< + components['schemas']['LineartImageProcessorInvocation'] +>; +export type MediapipeFaceProcessorInvocation = TypeReq< + components['schemas']['MediapipeFaceProcessorInvocation'] +>; +export type MidasDepthImageProcessorInvocation = TypeReq< + components['schemas']['MidasDepthImageProcessorInvocation'] +>; +export type MlsdImageProcessorInvocation = TypeReq< + components['schemas']['MlsdImageProcessorInvocation'] +>; +export type NormalbaeImageProcessorInvocation = TypeReq< + components['schemas']['NormalbaeImageProcessorInvocation'] +>; +export type OpenposeImageProcessorInvocation = TypeReq< + components['schemas']['OpenposeImageProcessorInvocation'] +>; +export type PidiImageProcessorInvocation = TypeReq< + components['schemas']['PidiImageProcessorInvocation'] +>; +export type ZoeDepthImageProcessorInvocation = TypeReq< + components['schemas']['ZoeDepthImageProcessorInvocation'] +>; // Node Outputs -export type ImageOutput = S<'ImageOutput'>; -export type MaskOutput = S<'MaskOutput'>; -export type PromptOutput = S<'PromptOutput'>; -export type IterateInvocationOutput = S<'IterateInvocationOutput'>; -export type CollectInvocationOutput = S<'CollectInvocationOutput'>; -export type LatentsOutput = S<'LatentsOutput'>; -export type GraphInvocationOutput = S<'GraphInvocationOutput'>; +export type ImageOutput = components['schemas']['ImageOutput']; +export type MaskOutput = components['schemas']['MaskOutput']; +export type PromptOutput = components['schemas']['PromptOutput']; +export type IterateInvocationOutput = + components['schemas']['IterateInvocationOutput']; +export type CollectInvocationOutput = + components['schemas']['CollectInvocationOutput']; +export type LatentsOutput = components['schemas']['LatentsOutput']; +export type GraphInvocationOutput = + components['schemas']['GraphInvocationOutput']; diff --git a/invokeai/frontend/web/src/theme/components/textarea.ts b/invokeai/frontend/web/src/theme/components/textarea.ts index 85e6e37d3f..b737cf5e57 100644 --- a/invokeai/frontend/web/src/theme/components/textarea.ts +++ b/invokeai/frontend/web/src/theme/components/textarea.ts @@ -1,7 +1,28 @@ import { defineStyle, defineStyleConfig } from '@chakra-ui/react'; import { getInputOutlineStyles } from '../util/getInputOutlineStyles'; -const invokeAI = defineStyle((props) => getInputOutlineStyles(props)); +const invokeAI = defineStyle((props) => ({ + ...getInputOutlineStyles(props), + '::-webkit-scrollbar': { + display: 'initial', + }, + '::-webkit-resizer': { + backgroundImage: `linear-gradient(135deg, + var(--invokeai-colors-base-50) 0%, + var(--invokeai-colors-base-50) 70%, + var(--invokeai-colors-base-200) 70%, + var(--invokeai-colors-base-200) 100%)`, + }, + _dark: { + '::-webkit-resizer': { + backgroundImage: `linear-gradient(135deg, + var(--invokeai-colors-base-900) 0%, + var(--invokeai-colors-base-900) 70%, + var(--invokeai-colors-base-800) 70%, + var(--invokeai-colors-base-800) 100%)`, + }, + }, +})); export const textareaTheme = defineStyleConfig({ variants: { diff --git a/invokeai/frontend/web/stats.html b/invokeai/frontend/web/stats.html index 1583812f37..7c7df1671a 100644 --- a/invokeai/frontend/web/stats.html +++ b/invokeai/frontend/web/stats.html @@ -4818,7 +4818,7 @@ var drawChart = (function (exports) {