Merge branch 'main' into lstein/model-manager-router-api

This commit is contained in:
Lincoln Stein 2023-07-05 23:16:43 -04:00 committed by GitHub
commit 8f5fcb188c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
102 changed files with 2138 additions and 907 deletions

2
.gitignore vendored
View File

@ -201,8 +201,6 @@ checkpoints
# If it's a Mac # If it's a Mac
.DS_Store .DS_Store
invokeai/frontend/web/dist/*
# Let the frontend manage its own gitignore # Let the frontend manage its own gitignore
!invokeai/frontend/web/* !invokeai/frontend/web/*

View File

@ -4,9 +4,10 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from inspect import signature from inspect import signature
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict, TYPE_CHECKING from typing import (TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args,
get_type_hints)
from pydantic import BaseModel, Field from pydantic import BaseConfig, BaseModel, Field
if TYPE_CHECKING: if TYPE_CHECKING:
from ..services.invocation_services import InvocationServices from ..services.invocation_services import InvocationServices
@ -65,7 +66,12 @@ class BaseInvocation(ABC, BaseModel):
@classmethod @classmethod
def get_invocations_map(cls): def get_invocations_map(cls):
# Get the type strings out of the literals and into a dictionary # Get the type strings out of the literals and into a dictionary
return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseInvocation.get_all_subclasses())) return dict(
map(
lambda t: (get_args(get_type_hints(t)["type"])[0], t),
BaseInvocation.get_all_subclasses(),
)
)
@classmethod @classmethod
def get_output_type(cls): def get_output_type(cls):
@ -76,10 +82,10 @@ class BaseInvocation(ABC, BaseModel):
"""Invoke with provided context and return outputs.""" """Invoke with provided context and return outputs."""
pass pass
#fmt: off # fmt: off
id: str = Field(description="The id of this node. Must be unique among all nodes.") id: str = Field(description="The id of this node. Must be unique among all nodes.")
is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.") is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
#fmt: on # fmt: on
# TODO: figure out a better way to provide these hints # TODO: figure out a better way to provide these hints
@ -98,16 +104,19 @@ class UIConfig(TypedDict, total=False):
"model", "model",
"control", "control",
"image_collection", "image_collection",
"vae_model",
"lora_model",
], ],
] ]
tags: List[str] tags: List[str]
title: str title: str
class CustomisedSchemaExtra(TypedDict): class CustomisedSchemaExtra(TypedDict):
ui: UIConfig ui: UIConfig
class InvocationConfig(BaseModel.Config): class InvocationConfig(BaseConfig):
"""Customizes pydantic's BaseModel.Config class for use by Invocations. """Customizes pydantic's BaseModel.Config class for use by Invocations.
Provide `schema_extra` a `ui` dict to add hints for generated UIs. Provide `schema_extra` a `ui` dict to add hints for generated UIs.

View File

@ -1,28 +1,28 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel, Field
from contextlib import ExitStack
import re import re
from contextlib import ExitStack
from typing import List, Literal, Optional, Union
import torch import torch
from compel import Compel
from compel.prompt_parser import (Blend, Conjunction,
CrossAttentionControlSubstitute,
FlattenedPrompt, Fragment)
from pydantic import BaseModel, Field
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig from ...backend.model_management.models import ModelNotFoundException
from .model import ClipField
from ...backend.util.devices import torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.model_management import BaseModelType, ModelType, SubModelType from ...backend.model_management import BaseModelType, ModelType, SubModelType
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from compel import Compel from ...backend.util.devices import torch_dtype
from compel.prompt_parser import ( from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
Blend, InvocationConfig, InvocationContext)
CrossAttentionControlSubstitute, from .model import ClipField
FlattenedPrompt,
Fragment, Conjunction,
)
class ConditioningField(BaseModel): class ConditioningField(BaseModel):
conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data") conditioning_name: Optional[str] = Field(
default=None, description="The name of conditioning data")
class Config: class Config:
schema_extra = {"required": ["conditioning_name"]} schema_extra = {"required": ["conditioning_name"]}
@ -52,84 +52,92 @@ class CompelInvocation(BaseInvocation):
"title": "Prompt (Compel)", "title": "Prompt (Compel)",
"tags": ["prompt", "compel"], "tags": ["prompt", "compel"],
"type_hints": { "type_hints": {
"model": "model" "model": "model"
} }
}, },
} }
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput: def invoke(self, context: InvocationContext) -> CompelOutput:
tokenizer_info = context.services.model_manager.get_model( tokenizer_info = context.services.model_manager.get_model(
**self.clip.tokenizer.dict(), **self.clip.tokenizer.dict(),
) )
text_encoder_info = context.services.model_manager.get_model( text_encoder_info = context.services.model_manager.get_model(
**self.clip.text_encoder.dict(), **self.clip.text_encoder.dict(),
) )
with tokenizer_info as orig_tokenizer,\
text_encoder_info as text_encoder:
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] def _lora_loader():
for lora in self.clip.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
del lora_info
return
ti_list = [] #loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
name = trigger[1:-1]
try:
ti_list.append(
context.services.model_manager.get_model(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
).context.model
)
except Exception:
#print(e)
#import traceback
#print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found")
with ModelPatcher.apply_lora_text_encoder(text_encoder, loras),\ ti_list = []
ModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager): for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
name = trigger[1:-1]
compel = Compel( try:
tokenizer=tokenizer, ti_list.append(
text_encoder=text_encoder, context.services.model_manager.get_model(
textual_inversion_manager=ti_manager, model_name=name,
dtype_for_device_getter=torch_dtype, base_model=self.clip.text_encoder.base_model,
truncate_long_prompts=True, # TODO: model_type=ModelType.TextualInversion,
).context.model
) )
except ModelNotFoundException:
# print(e)
#import traceback
#print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found")
conjunction = Compel.parse_prompt_string(self.prompt) with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0] ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
text_encoder_info as text_encoder:
if context.services.configuration.log_tokenization: compel = Compel(
log_tokenization_for_prompt_object(prompt, tokenizer) tokenizer=tokenizer,
text_encoder=text_encoder,
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt) textual_inversion_manager=ti_manager,
dtype_for_device_getter=torch_dtype,
# TODO: long prompt support truncate_long_prompts=True, # TODO:
#if not self.truncate_long_prompts:
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
cross_attention_control_args=options.get("cross_attention_control", None),
)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
# TODO: hacky but works ;D maybe rename latents somehow?
context.services.latents.save(conditioning_name, (c, ec))
return CompelOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
) )
conjunction = Compel.parse_prompt_string(self.prompt)
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
if context.services.configuration.log_tokenization:
log_tokenization_for_prompt_object(prompt, tokenizer)
c, options = compel.build_conditioning_tensor_for_prompt_object(
prompt)
# TODO: long prompt support
# if not self.truncate_long_prompts:
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(
tokenizer, conjunction),
cross_attention_control_args=options.get(
"cross_attention_control", None),)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
# TODO: hacky but works ;D maybe rename latents somehow?
context.services.latents.save(conditioning_name, (c, ec))
return CompelOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
def get_max_token_count( def get_max_token_count(
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction],
) -> int: truncate_if_too_long=False) -> int:
if type(prompt) is Blend: if type(prompt) is Blend:
blend: Blend = prompt blend: Blend = prompt
return max( return max(
@ -148,13 +156,13 @@ def get_max_token_count(
) )
else: else:
return len( return len(
get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long) get_tokens_for_prompt_object(
) tokenizer, prompt, truncate_if_too_long))
def get_tokens_for_prompt_object( def get_tokens_for_prompt_object(
tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
) -> [str]: ) -> List[str]:
if type(parsed_prompt) is Blend: if type(parsed_prompt) is Blend:
raise ValueError( raise ValueError(
"Blend is not supported here - you need to get tokens for each of its .children" "Blend is not supported here - you need to get tokens for each of its .children"
@ -183,7 +191,7 @@ def log_tokenization_for_conjunction(
): ):
display_label_prefix = display_label_prefix or "" display_label_prefix = display_label_prefix or ""
for i, p in enumerate(c.prompts): for i, p in enumerate(c.prompts):
if len(c.prompts)>1: if len(c.prompts) > 1:
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})" this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
else: else:
this_display_label_prefix = display_label_prefix this_display_label_prefix = display_label_prefix
@ -238,7 +246,8 @@ def log_tokenization_for_prompt_object(
) )
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False): def log_tokenization_for_text(
text, tokenizer, display_label=None, truncate_if_too_long=False):
"""shows how the prompt is tokenized """shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word, # usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' ' # but for readability it has been replaced with ' '

View File

@ -4,18 +4,17 @@ from contextlib import ExitStack
from typing import List, Literal, Optional, Union from typing import List, Literal, Optional, Union
import einops import einops
from pydantic import BaseModel, Field, validator
import torch import torch
from diffusers import ControlNetModel, DPMSolverMultistepScheduler from diffusers import ControlNetModel, DPMSolverMultistepScheduler
from diffusers.image_processor import VaeImageProcessor from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import BaseModel, Field, validator
from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ...backend.image_util.seamless import configure_model_padding from ...backend.image_util.seamless import configure_model_padding
from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import ( from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline, ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline,
@ -24,7 +23,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
PostprocessingSettings PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import torch_dtype from ...backend.util.devices import torch_dtype
from ...backend.model_management.lora import ModelPatcher from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (BaseInvocation, BaseInvocationOutput, from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext) InvocationConfig, InvocationContext)
from .compel import ConditioningField from .compel import ConditioningField
@ -32,14 +31,17 @@ from .controlnet_image_processors import ControlField
from .image import ImageOutput from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField from .model import ModelInfo, UNetField, VaeField
class LatentsField(BaseModel): class LatentsField(BaseModel):
"""A latents field used for passing latents between invocations""" """A latents field used for passing latents between invocations"""
latents_name: Optional[str] = Field(default=None, description="The name of the latents") latents_name: Optional[str] = Field(
default=None, description="The name of the latents")
class Config: class Config:
schema_extra = {"required": ["latents_name"]} schema_extra = {"required": ["latents_name"]}
class LatentsOutput(BaseInvocationOutput): class LatentsOutput(BaseInvocationOutput):
"""Base class for invocations that output latents""" """Base class for invocations that output latents"""
#fmt: off #fmt: off
@ -53,11 +55,11 @@ class LatentsOutput(BaseInvocationOutput):
def build_latents_output(latents_name: str, latents: torch.Tensor): def build_latents_output(latents_name: str, latents: torch.Tensor):
return LatentsOutput( return LatentsOutput(
latents=LatentsField(latents_name=latents_name), latents=LatentsField(latents_name=latents_name),
width=latents.size()[3] * 8, width=latents.size()[3] * 8,
height=latents.size()[2] * 8, height=latents.size()[2] * 8,
) )
SAMPLER_NAME_VALUES = Literal[ SAMPLER_NAME_VALUES = Literal[
@ -70,14 +72,17 @@ def get_scheduler(
scheduler_info: ModelInfo, scheduler_info: ModelInfo,
scheduler_name: str, scheduler_name: str,
) -> Scheduler: ) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim']) scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict()) scheduler_name, SCHEDULER_MAP['ddim'])
orig_scheduler_info = context.services.model_manager.get_model(
**scheduler_info.dict())
with orig_scheduler_info as orig_scheduler: with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config scheduler_config = orig_scheduler.config
if "_backup" in scheduler_config: if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"] scheduler_config = scheduler_config["_backup"]
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config} scheduler_config = {**scheduler_config, **
scheduler_extra_config, "_backup": scheduler_config}
scheduler = scheduler_class.from_config(scheduler_config) scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py # hack copied over from generate.py
@ -124,18 +129,18 @@ class TextToLatentsInvocation(BaseInvocation):
"ui": { "ui": {
"tags": ["latents"], "tags": ["latents"],
"type_hints": { "type_hints": {
"model": "model", "model": "model",
"control": "control", "control": "control",
# "cfg_scale": "float", # "cfg_scale": "float",
"cfg_scale": "number" "cfg_scale": "number"
} }
}, },
} }
# TODO: pass this an emitter method or something? or a session for dispatching? # TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress( def dispatch_progress(
self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState self, context: InvocationContext, source_node_id: str,
) -> None: intermediate_state: PipelineIntermediateState) -> None:
stable_diffusion_step_callback( stable_diffusion_step_callback(
context=context, context=context,
intermediate_state=intermediate_state, intermediate_state=intermediate_state,
@ -143,9 +148,12 @@ class TextToLatentsInvocation(BaseInvocation):
source_node_id=source_node_id, source_node_id=source_node_id,
) )
def get_conditioning_data(self, context: InvocationContext, scheduler) -> ConditioningData: def get_conditioning_data(
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name) self, context: InvocationContext, scheduler) -> ConditioningData:
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name) c, extra_conditioning_info = context.services.latents.get(
self.positive_conditioning.conditioning_name)
uc, _ = context.services.latents.get(
self.negative_conditioning.conditioning_name)
conditioning_data = ConditioningData( conditioning_data = ConditioningData(
unconditioned_embeddings=uc, unconditioned_embeddings=uc,
@ -153,10 +161,10 @@ class TextToLatentsInvocation(BaseInvocation):
guidance_scale=self.cfg_scale, guidance_scale=self.cfg_scale,
extra=extra_conditioning_info, extra=extra_conditioning_info,
postprocessing_settings=PostprocessingSettings( postprocessing_settings=PostprocessingSettings(
threshold=0.0,#threshold, threshold=0.0, # threshold,
warmup=0.2,#warmup, warmup=0.2, # warmup,
h_symmetry_time_pct=None,#h_symmetry_time_pct, h_symmetry_time_pct=None, # h_symmetry_time_pct,
v_symmetry_time_pct=None#v_symmetry_time_pct, v_symmetry_time_pct=None # v_symmetry_time_pct,
), ),
) )
@ -164,20 +172,21 @@ class TextToLatentsInvocation(BaseInvocation):
scheduler, scheduler,
# for ddim scheduler # for ddim scheduler
eta=0.0, #ddim_eta eta=0.0, # ddim_eta
# for ancestral and sde schedulers # for ancestral and sde schedulers
generator=torch.Generator(device=uc.device).manual_seed(0), generator=torch.Generator(device=uc.device).manual_seed(0),
) )
return conditioning_data return conditioning_data
def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline: def create_pipeline(
self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
# TODO: # TODO:
#configure_model_padding( # configure_model_padding(
# unet, # unet,
# self.seamless, # self.seamless,
# self.seamless_axes, # self.seamless_axes,
#) # )
class FakeVae: class FakeVae:
class FakeVaeConfig: class FakeVaeConfig:
@ -188,7 +197,7 @@ class TextToLatentsInvocation(BaseInvocation):
self.config = FakeVae.FakeVaeConfig() self.config = FakeVae.FakeVaeConfig()
return StableDiffusionGeneratorPipeline( return StableDiffusionGeneratorPipeline(
vae=FakeVae(), # TODO: oh... vae=FakeVae(), # TODO: oh...
text_encoder=None, text_encoder=None,
tokenizer=None, tokenizer=None,
unet=unet, unet=unet,
@ -202,7 +211,8 @@ class TextToLatentsInvocation(BaseInvocation):
def prep_control_data( def prep_control_data(
self, self,
context: InvocationContext, context: InvocationContext,
model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device # really only need model for dtype and device
model: StableDiffusionGeneratorPipeline,
control_input: List[ControlField], control_input: List[ControlField],
latents_shape: List[int], latents_shape: List[int],
do_classifier_free_guidance: bool = True, do_classifier_free_guidance: bool = True,
@ -238,15 +248,17 @@ class TextToLatentsInvocation(BaseInvocation):
print("Using HF model subfolders") print("Using HF model subfolders")
print(" control_name: ", control_name) print(" control_name: ", control_name)
print(" control_subfolder: ", control_subfolder) print(" control_subfolder: ", control_subfolder)
control_model = ControlNetModel.from_pretrained(control_name, control_model = ControlNetModel.from_pretrained(
subfolder=control_subfolder, control_name, subfolder=control_subfolder,
torch_dtype=model.unet.dtype).to(model.device) torch_dtype=model.unet.dtype).to(
model.device)
else: else:
control_model = ControlNetModel.from_pretrained(control_info.control_model, control_model = ControlNetModel.from_pretrained(
torch_dtype=model.unet.dtype).to(model.device) control_info.control_model, torch_dtype=model.unet.dtype).to(model.device)
control_models.append(control_model) control_models.append(control_model)
control_image_field = control_info.image control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_name) input_image = context.services.images.get_pil_image(
control_image_field.image_name)
# self.image.image_type, self.image.image_name # self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes # FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt? # and add in batch_size, num_images_per_prompt?
@ -263,29 +275,40 @@ class TextToLatentsInvocation(BaseInvocation):
dtype=control_model.dtype, dtype=control_model.dtype,
control_mode=control_info.control_mode, control_mode=control_info.control_mode,
) )
control_item = ControlNetData(model=control_model, control_item = ControlNetData(
image_tensor=control_image, model=control_model, image_tensor=control_image,
weight=control_info.control_weight, weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent, begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent, end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode, control_mode=control_info.control_mode,)
)
control_data.append(control_item) control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData] # MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data return control_data
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name) noise = context.services.latents.get(self.noise.latents_name)
# Get the source node id (we are invoking the prepared node) # Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id] source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState): def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state) self.dispatch_progress(context, source_node_id, state)
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict()) def _lora_loader():
with unet_info as unet: for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict())
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet:
scheduler = get_scheduler( scheduler = get_scheduler(
context=context, context=context,
@ -296,8 +319,6 @@ class TextToLatentsInvocation(BaseInvocation):
pipeline = self.create_pipeline(unet, scheduler) pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler) conditioning_data = self.get_conditioning_data(context, scheduler)
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
control_data = self.prep_control_data( control_data = self.prep_control_data(
model=pipeline, context=context, control_input=self.control, model=pipeline, context=context, control_input=self.control,
latents_shape=noise.shape, latents_shape=noise.shape,
@ -305,16 +326,15 @@ class TextToLatentsInvocation(BaseInvocation):
do_classifier_free_guidance=True, do_classifier_free_guidance=True,
) )
with ModelPatcher.apply_lora_unet(pipeline.unet, loras): # TODO: Verify the noise is the right size
# TODO: Verify the noise is the right size result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)), noise=noise,
noise=noise, num_inference_steps=self.steps,
num_inference_steps=self.steps, conditioning_data=conditioning_data,
conditioning_data=conditioning_data, control_data=control_data, # list[ControlNetData]
control_data=control_data, # list[ControlNetData] callback=step_callback,
callback=step_callback, )
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -323,14 +343,18 @@ class TextToLatentsInvocation(BaseInvocation):
context.services.latents.save(name, result_latents) context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents) return build_latents_output(latents_name=name, latents=result_latents)
class LatentsToLatentsInvocation(TextToLatentsInvocation): class LatentsToLatentsInvocation(TextToLatentsInvocation):
"""Generates latents using latents as base image.""" """Generates latents using latents as base image."""
type: Literal["l2l"] = "l2l" type: Literal["l2l"] = "l2l"
# Inputs # Inputs
latents: Optional[LatentsField] = Field(description="The latents to use as a base image") latents: Optional[LatentsField] = Field(
strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use") description="The latents to use as a base image")
strength: float = Field(
default=0.7, ge=0, le=1,
description="The strength of the latents to use")
# Schema customisation # Schema customisation
class Config(InvocationConfig): class Config(InvocationConfig):
@ -345,22 +369,31 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
}, },
} }
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name) noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name) latent = context.services.latents.get(self.latents.latents_name)
# Get the source node id (we are invoking the prepared node) # Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id] source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState): def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state) self.dispatch_progress(context, source_node_id, state)
unet_info = context.services.model_manager.get_model( def _lora_loader():
**self.unet.unet.dict(), for lora in self.unet.loras:
) lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
del lora_info
return
with unet_info as unet: unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict())
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet:
scheduler = get_scheduler( scheduler = get_scheduler(
context=context, context=context,
@ -380,8 +413,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
# TODO: Verify the noise is the right size # TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like( initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
latent, device=unet.device, dtype=latent.dtype latent, device=unet.device, dtype=latent.dtype)
)
timesteps, _ = pipeline.get_img2img_timesteps( timesteps, _ = pipeline.get_img2img_timesteps(
self.steps, self.steps,
@ -389,18 +421,15 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
device=unet.device, device=unet.device,
) )
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras] result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=initial_latents,
with ModelPatcher.apply_lora_unet(pipeline.unet, loras): timesteps=timesteps,
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( noise=noise,
latents=initial_latents, num_inference_steps=self.steps,
timesteps=timesteps, conditioning_data=conditioning_data,
noise=noise, control_data=control_data, # list[ControlNetData]
num_inference_steps=self.steps, callback=step_callback
conditioning_data=conditioning_data, )
control_data=control_data, # list[ControlNetData]
callback=step_callback
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -417,9 +446,12 @@ class LatentsToImageInvocation(BaseInvocation):
type: Literal["l2i"] = "l2i" type: Literal["l2i"] = "l2i"
# Inputs # Inputs
latents: Optional[LatentsField] = Field(description="The latents to generate an image from") latents: Optional[LatentsField] = Field(
description="The latents to generate an image from")
vae: VaeField = Field(default=None, description="Vae submodel") vae: VaeField = Field(default=None, description="Vae submodel")
tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)") tiled: bool = Field(
default=False,
description="Decode latents by overlaping tiles(less memory consumption)")
# Schema customisation # Schema customisation
class Config(InvocationConfig): class Config(InvocationConfig):
@ -450,7 +482,7 @@ class LatentsToImageInvocation(BaseInvocation):
# copied from diffusers pipeline # copied from diffusers pipeline
latents = latents / vae.config.scaling_factor latents = latents / vae.config.scaling_factor
image = vae.decode(latents, return_dict=False)[0] image = vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1) # denormalize image = (image / 2 + 0.5).clamp(0, 1) # denormalize
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
np_image = image.cpu().permute(0, 2, 3, 1).float().numpy() np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@ -473,9 +505,9 @@ class LatentsToImageInvocation(BaseInvocation):
height=image_dto.height, height=image_dto.height,
) )
LATENTS_INTERPOLATION_MODE = Literal[
"nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact" LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear",
] "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
class ResizeLatentsInvocation(BaseInvocation): class ResizeLatentsInvocation(BaseInvocation):
@ -484,21 +516,25 @@ class ResizeLatentsInvocation(BaseInvocation):
type: Literal["lresize"] = "lresize" type: Literal["lresize"] = "lresize"
# Inputs # Inputs
latents: Optional[LatentsField] = Field(description="The latents to resize") latents: Optional[LatentsField] = Field(
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)") description="The latents to resize")
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)") width: int = Field(
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode") ge=64, multiple_of=8, description="The width to resize to (px)")
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") height: int = Field(
ge=64, multiple_of=8, description="The height to resize to (px)")
mode: LATENTS_INTERPOLATION_MODE = Field(
default="bilinear", description="The interpolation mode")
antialias: bool = Field(
default=False,
description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name) latents = context.services.latents.get(self.latents.latents_name)
resized_latents = torch.nn.functional.interpolate( resized_latents = torch.nn.functional.interpolate(
latents, latents, size=(self.height // 8, self.width // 8),
size=(self.height // 8, self.width // 8), mode=self.mode, antialias=self.antialias
mode=self.mode, if self.mode in ["bilinear", "bicubic"] else False,)
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -515,21 +551,24 @@ class ScaleLatentsInvocation(BaseInvocation):
type: Literal["lscale"] = "lscale" type: Literal["lscale"] = "lscale"
# Inputs # Inputs
latents: Optional[LatentsField] = Field(description="The latents to scale") latents: Optional[LatentsField] = Field(
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents") description="The latents to scale")
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode") scale_factor: float = Field(
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") gt=0, description="The factor by which to scale the latents")
mode: LATENTS_INTERPOLATION_MODE = Field(
default="bilinear", description="The interpolation mode")
antialias: bool = Field(
default=False,
description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name) latents = context.services.latents.get(self.latents.latents_name)
# resizing # resizing
resized_latents = torch.nn.functional.interpolate( resized_latents = torch.nn.functional.interpolate(
latents, latents, scale_factor=self.scale_factor, mode=self.mode,
scale_factor=self.scale_factor, antialias=self.antialias
mode=self.mode, if self.mode in ["bilinear", "bicubic"] else False,)
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -548,7 +587,9 @@ class ImageToLatentsInvocation(BaseInvocation):
# Inputs # Inputs
image: Union[ImageField, None] = Field(description="The image to encode") image: Union[ImageField, None] = Field(description="The image to encode")
vae: VaeField = Field(default=None, description="Vae submodel") vae: VaeField = Field(default=None, description="Vae submodel")
tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)") tiled: bool = Field(
default=False,
description="Encode latents by overlaping tiles(less memory consumption)")
# Schema customisation # Schema customisation
class Config(InvocationConfig): class Config(InvocationConfig):

View File

@ -1,5 +1,5 @@
import copy import copy
from typing import List, Literal, Optional from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -12,35 +12,42 @@ class ModelInfo(BaseModel):
model_name: str = Field(description="Info to load submodel") model_name: str = Field(description="Info to load submodel")
base_model: BaseModelType = Field(description="Base model") base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Info to load submodel") model_type: ModelType = Field(description="Info to load submodel")
submodel: Optional[SubModelType] = Field(description="Info to load submodel") submodel: Optional[SubModelType] = Field(
default=None, description="Info to load submodel"
)
class LoraInfo(ModelInfo): class LoraInfo(ModelInfo):
weight: float = Field(description="Lora's weight which to use when apply to model") weight: float = Field(description="Lora's weight which to use when apply to model")
class UNetField(BaseModel): class UNetField(BaseModel):
unet: ModelInfo = Field(description="Info to load unet submodel") unet: ModelInfo = Field(description="Info to load unet submodel")
scheduler: ModelInfo = Field(description="Info to load scheduler submodel") scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading") loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
class ClipField(BaseModel): class ClipField(BaseModel):
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel") tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel") text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading") loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
class VaeField(BaseModel): class VaeField(BaseModel):
# TODO: better naming? # TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel") vae: ModelInfo = Field(description="Info to load vae submodel")
class ModelLoaderOutput(BaseInvocationOutput): class ModelLoaderOutput(BaseInvocationOutput):
"""Model loader output""" """Model loader output"""
#fmt: off # fmt: off
type: Literal["model_loader_output"] = "model_loader_output" type: Literal["model_loader_output"] = "model_loader_output"
unet: UNetField = Field(default=None, description="UNet submodel") unet: UNetField = Field(default=None, description="UNet submodel")
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
vae: VaeField = Field(default=None, description="Vae submodel") vae: VaeField = Field(default=None, description="Vae submodel")
#fmt: on # fmt: on
class MainModelField(BaseModel): class MainModelField(BaseModel):
@ -50,6 +57,13 @@ class MainModelField(BaseModel):
base_model: BaseModelType = Field(description="Base model") base_model: BaseModelType = Field(description="Base model")
class LoRAModelField(BaseModel):
"""LoRA model field"""
model_name: str = Field(description="Name of the LoRA model")
base_model: BaseModelType = Field(description="Base model")
class MainModelLoaderInvocation(BaseInvocation): class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels.""" """Loads a main model, outputting its submodels."""
@ -64,14 +78,11 @@ class MainModelLoaderInvocation(BaseInvocation):
"ui": { "ui": {
"title": "Model Loader", "title": "Model Loader",
"tags": ["model", "loader"], "tags": ["model", "loader"],
"type_hints": { "type_hints": {"model": "model"},
"model": "model"
}
}, },
} }
def invoke(self, context: InvocationContext) -> ModelLoaderOutput: def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = self.model.base_model base_model = self.model.base_model
model_name = self.model.model_name model_name = self.model.model_name
model_type = ModelType.Main model_type = ModelType.Main
@ -113,7 +124,6 @@ class MainModelLoaderInvocation(BaseInvocation):
) )
""" """
return ModelLoaderOutput( return ModelLoaderOutput(
unet=UNetField( unet=UNetField(
unet=ModelInfo( unet=ModelInfo(
@ -152,25 +162,29 @@ class MainModelLoaderInvocation(BaseInvocation):
model_type=model_type, model_type=model_type,
submodel=SubModelType.Vae, submodel=SubModelType.Vae,
), ),
) ),
) )
class LoraLoaderOutput(BaseInvocationOutput): class LoraLoaderOutput(BaseInvocationOutput):
"""Model loader output""" """Model loader output"""
#fmt: off # fmt: off
type: Literal["lora_loader_output"] = "lora_loader_output" type: Literal["lora_loader_output"] = "lora_loader_output"
unet: Optional[UNetField] = Field(default=None, description="UNet submodel") unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels") clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
#fmt: on # fmt: on
class LoraLoaderInvocation(BaseInvocation): class LoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder.""" """Apply selected lora to unet and text_encoder."""
type: Literal["lora_loader"] = "lora_loader" type: Literal["lora_loader"] = "lora_loader"
lora_name: str = Field(description="Lora model name") lora: Union[LoRAModelField, None] = Field(
default=None, description="Lora model name"
)
weight: float = Field(default=0.75, description="With what weight to apply lora") weight: float = Field(default=0.75, description="With what weight to apply lora")
unet: Optional[UNetField] = Field(description="UNet model for applying lora") unet: Optional[UNetField] = Field(description="UNet model for applying lora")
@ -181,26 +195,33 @@ class LoraLoaderInvocation(BaseInvocation):
"ui": { "ui": {
"title": "Lora Loader", "title": "Lora Loader",
"tags": ["lora", "loader"], "tags": ["lora", "loader"],
"type_hints": {"lora": "lora_model"},
}, },
} }
def invoke(self, context: InvocationContext) -> LoraLoaderOutput: def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
# TODO: ui rewrite base_model = self.lora.base_model
base_model = BaseModelType.StableDiffusion1 lora_name = self.lora.model_name
if not context.services.model_manager.model_exists( if not context.services.model_manager.model_exists(
base_model=base_model, base_model=base_model,
model_name=self.lora_name, model_name=lora_name,
model_type=ModelType.Lora, model_type=ModelType.Lora,
): ):
raise Exception(f"Unkown lora name: {self.lora_name}!") raise Exception(f"Unkown lora name: {lora_name}!")
if self.unet is not None and any(lora.model_name == self.lora_name for lora in self.unet.loras): if self.unet is not None and any(
raise Exception(f"Lora \"{self.lora_name}\" already applied to unet") lora.model_name == lora_name for lora in self.unet.loras
):
raise Exception(f'Lora "{lora_name}" already applied to unet')
if self.clip is not None and any(lora.model_name == self.lora_name for lora in self.clip.loras): if self.clip is not None and any(
raise Exception(f"Lora \"{self.lora_name}\" already applied to clip") lora.model_name == lora_name for lora in self.clip.loras
):
raise Exception(f'Lora "{lora_name}" already applied to clip')
output = LoraLoaderOutput() output = LoraLoaderOutput()
@ -209,7 +230,7 @@ class LoraLoaderInvocation(BaseInvocation):
output.unet.loras.append( output.unet.loras.append(
LoraInfo( LoraInfo(
base_model=base_model, base_model=base_model,
model_name=self.lora_name, model_name=lora_name,
model_type=ModelType.Lora, model_type=ModelType.Lora,
submodel=None, submodel=None,
weight=self.weight, weight=self.weight,
@ -221,7 +242,7 @@ class LoraLoaderInvocation(BaseInvocation):
output.clip.loras.append( output.clip.loras.append(
LoraInfo( LoraInfo(
base_model=base_model, base_model=base_model,
model_name=self.lora_name, model_name=lora_name,
model_type=ModelType.Lora, model_type=ModelType.Lora,
submodel=None, submodel=None,
weight=self.weight, weight=self.weight,
@ -230,23 +251,27 @@ class LoraLoaderInvocation(BaseInvocation):
return output return output
class VAEModelField(BaseModel): class VAEModelField(BaseModel):
"""Vae model field""" """Vae model field"""
model_name: str = Field(description="Name of the model") model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model") base_model: BaseModelType = Field(description="Base model")
class VaeLoaderOutput(BaseInvocationOutput): class VaeLoaderOutput(BaseInvocationOutput):
"""Model loader output""" """Model loader output"""
#fmt: off # fmt: off
type: Literal["vae_loader_output"] = "vae_loader_output" type: Literal["vae_loader_output"] = "vae_loader_output"
vae: VaeField = Field(default=None, description="Vae model") vae: VaeField = Field(default=None, description="Vae model")
#fmt: on # fmt: on
class VaeLoaderInvocation(BaseInvocation): class VaeLoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput""" """Loads a VAE model, outputting a VaeLoaderOutput"""
type: Literal["vae_loader"] = "vae_loader" type: Literal["vae_loader"] = "vae_loader"
vae_model: VAEModelField = Field(description="The VAE to load") vae_model: VAEModelField = Field(description="The VAE to load")
@ -257,9 +282,7 @@ class VaeLoaderInvocation(BaseInvocation):
"ui": { "ui": {
"title": "VAE Loader", "title": "VAE Loader",
"tags": ["vae", "loader"], "tags": ["vae", "loader"],
"type_hints": { "type_hints": {"vae_model": "vae_model"},
"vae_model": "vae_model"
}
}, },
} }
@ -269,17 +292,17 @@ class VaeLoaderInvocation(BaseInvocation):
model_type = ModelType.Vae model_type = ModelType.Vae
if not context.services.model_manager.model_exists( if not context.services.model_manager.model_exists(
base_model=base_model, base_model=base_model,
model_name=model_name, model_name=model_name,
model_type=model_type, model_type=model_type,
): ):
raise Exception(f"Unkown vae name: {model_name}!") raise Exception(f"Unkown vae name: {model_name}!")
return VaeLoaderOutput( return VaeLoaderOutput(
vae=VaeField( vae=VaeField(
vae = ModelInfo( vae=ModelInfo(
model_name = model_name, model_name=model_name,
base_model = base_model, base_model=base_model,
model_type = model_type, model_type=model_type,
) )
) )
) )

View File

@ -367,7 +367,8 @@ setting environment variables INVOKEAI_<setting>.
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance') always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance') free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
max_loaded_models : int = Field(default=3, gt=0, description="Maximum number of models to keep in memory for rapid switching", category='Memory/Performance') max_loaded_models : int = Field(default=3, gt=0, description="(DEPRECATED: use max_cache_size) Maximum number of models to keep in memory for rapid switching", category='Memory/Performance')
max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance') precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance')
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance') sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance') xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')

View File

@ -255,6 +255,8 @@ class ModelManagerService(ModelManagerServiceBase):
if hasattr(config,'max_cache_size') \ if hasattr(config,'max_cache_size') \
else config.max_loaded_models * 2.5 else config.max_loaded_models * 2.5
logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")
sequential_offload = config.sequential_guidance sequential_offload = config.sequential_guidance
self.mgr = ModelManager( self.mgr = ModelManager(

View File

@ -76,6 +76,10 @@ class MigrateTo3(object):
Create a unique name for a model for use within models.yaml. Create a unique name for a model for use within models.yaml.
''' '''
done = False done = False
# some model names have slashes in them, which really screws things up
name = name.replace('/','_')
key = ModelManager.create_key(name,info.base_type,info.model_type) key = ModelManager.create_key(name,info.base_type,info.model_type)
unique_name = key unique_name = key
counter = 1 counter = 1
@ -219,11 +223,12 @@ class MigrateTo3(object):
repo_id = 'openai/clip-vit-large-patch14' repo_id = 'openai/clip-vit-large-patch14'
self._migrate_pretrained(CLIPTokenizer, self._migrate_pretrained(CLIPTokenizer,
repo_id= repo_id, repo_id= repo_id,
dest= target_dir / 'clip-vit-large-patch14' / 'tokenizer', dest= target_dir / 'clip-vit-large-patch14',
**kwargs) **kwargs)
self._migrate_pretrained(CLIPTextModel, self._migrate_pretrained(CLIPTextModel,
repo_id = repo_id, repo_id = repo_id,
dest = target_dir / 'clip-vit-large-patch14' / 'text_encoder', dest = target_dir / 'clip-vit-large-patch14',
force = True,
**kwargs) **kwargs)
# sd-2 # sd-2
@ -287,21 +292,21 @@ class MigrateTo3(object):
def _model_probe_to_path(self, info: ModelProbeInfo)->Path: def _model_probe_to_path(self, info: ModelProbeInfo)->Path:
return Path(self.dest_models, info.base_type.value, info.model_type.value) return Path(self.dest_models, info.base_type.value, info.model_type.value)
def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, **kwargs): def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force:bool=False, **kwargs):
if dest.exists(): if dest.exists() and not force:
logger.info(f'Skipping existing {dest}') logger.info(f'Skipping existing {dest}')
return return
model = model_class.from_pretrained(repo_id, **kwargs) model = model_class.from_pretrained(repo_id, **kwargs)
self._save_pretrained(model, dest) self._save_pretrained(model, dest, overwrite=force)
def _save_pretrained(self, model, dest: Path): def _save_pretrained(self, model, dest: Path, overwrite: bool=False):
if dest.exists():
logger.info(f'Skipping existing {dest}')
return
model_name = dest.name model_name = dest.name
download_path = dest.with_name(f'{model_name}.downloading') if overwrite:
model.save_pretrained(download_path, safe_serialization=True) model.save_pretrained(dest, safe_serialization=True)
download_path.replace(dest) else:
download_path = dest.with_name(f'{model_name}.downloading')
model.save_pretrained(download_path, safe_serialization=True)
download_path.replace(dest)
def _download_vae(self, repo_id: str, subfolder:str=None)->Path: def _download_vae(self, repo_id: str, subfolder:str=None)->Path:
vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / 'models/hub', subfolder=subfolder) vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / 'models/hub', subfolder=subfolder)
@ -569,8 +574,10 @@ script, which will perform a full upgrade in place."""
dest_directory = args.dest_directory dest_directory = args.dest_directory
assert dest_directory.is_dir(), f"{dest_directory} is not a valid directory" assert dest_directory.is_dir(), f"{dest_directory} is not a valid directory"
assert (dest_directory / 'models').is_dir(), f"{dest_directory} does not contain a 'models' subdirectory"
assert (dest_directory / 'invokeai.yaml').exists(), f"{dest_directory} does not contain an InvokeAI init file." # TODO: revisit
# assert (dest_directory / 'models').is_dir(), f"{dest_directory} does not contain a 'models' subdirectory"
# assert (dest_directory / 'invokeai.yaml').exists(), f"{dest_directory} does not contain an InvokeAI init file."
do_migrate(root_directory,dest_directory) do_migrate(root_directory,dest_directory)

View File

@ -236,7 +236,6 @@ class ModelInstall(object):
) )
def _install_url(self, url: str)->AddModelResult: def _install_url(self, url: str)->AddModelResult:
# copy to a staging area, probe, import and delete
with TemporaryDirectory(dir=self.config.models_path) as staging: with TemporaryDirectory(dir=self.config.models_path) as staging:
location = download_with_resume(url,Path(staging)) location = download_with_resume(url,Path(staging))
if not location: if not location:

View File

@ -29,7 +29,7 @@ import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from .model_manager import ModelManager from .model_manager import ModelManager
from .model_cache import ModelCache from picklescan.scanner import scan_file_path
from .models import BaseModelType, ModelVariantType from .models import BaseModelType, ModelVariantType
try: try:
@ -1014,7 +1014,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint = load_file(checkpoint_path) checkpoint = load_file(checkpoint_path)
else: else:
if scan_needed: if scan_needed:
ModelCache.scan_model(checkpoint_path, checkpoint_path) # scan model
scan_result = scan_file_path(checkpoint_path)
if scan_result.infected_files != 0:
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
checkpoint = torch.load(checkpoint_path) checkpoint = torch.load(checkpoint_path)
# sometimes there is a state_dict key and sometimes not # sometimes there is a state_dict key and sometimes not

View File

@ -1,16 +1,17 @@
from __future__ import annotations from __future__ import annotations
import copy import copy
from pathlib import Path
from contextlib import contextmanager from contextlib import contextmanager
from typing import Optional, Dict, Tuple, Any, Union, List from typing import Optional, Dict, Tuple, Any, Union, List
import torch from pathlib import Path
from safetensors.torch import load_file
import torch
from compel.embeddings_provider import BaseTextualInversionManager
from diffusers.models import UNet2DConditionModel
from safetensors.torch import load_file
from diffusers.models import UNet2DConditionModel from diffusers.models import UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer from transformers import CLIPTextModel, CLIPTokenizer
from torch.utils.hooks import RemovableHandle
from compel.embeddings_provider import BaseTextualInversionManager
class LoRALayerBase: class LoRALayerBase:
#rank: Optional[int] #rank: Optional[int]
@ -537,9 +538,10 @@ class ModelPatcher:
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True) original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
# enable autocast to calc fp16 loras on cpu # enable autocast to calc fp16 loras on cpu
with torch.autocast(device_type="cpu"): #with torch.autocast(device_type="cpu"):
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 layer.to(dtype=torch.float32)
layer_weight = layer.get_weight() * lora_weight * layer_scale layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
layer_weight = layer.get_weight() * lora_weight * layer_scale
if module.weight.shape != layer_weight.shape: if module.weight.shape != layer_weight.shape:
# TODO: debug on lycoris # TODO: debug on lycoris
@ -653,6 +655,9 @@ class TextualInversionModel:
else: else:
result.embedding = next(iter(state_dict.values())) result.embedding = next(iter(state_dict.values()))
if len(result.embedding.shape) == 1:
result.embedding = result.embedding.unsqueeze(0)
if not isinstance(result.embedding, torch.Tensor): if not isinstance(result.embedding, torch.Tensor):
raise ValueError(f"Invalid embeddings file: {file_path.name}") raise ValueError(f"Invalid embeddings file: {file_path.name}")

View File

@ -100,8 +100,6 @@ class ModelCache(object):
:param sha_chunksize: Chunksize to use when calculating sha256 model hash :param sha_chunksize: Chunksize to use when calculating sha256 model hash
''' '''
#max_cache_size = 9999 #max_cache_size = 9999
execution_device = torch.device('cuda')
self.model_infos: Dict[str, ModelBase] = dict() self.model_infos: Dict[str, ModelBase] = dict()
self.lazy_offloading = lazy_offloading self.lazy_offloading = lazy_offloading
#self.sequential_offload: bool=sequential_offload #self.sequential_offload: bool=sequential_offload

View File

@ -249,7 +249,7 @@ from .model_cache import ModelCache, ModelLocker
from .models import ( from .models import (
BaseModelType, ModelType, SubModelType, BaseModelType, ModelType, SubModelType,
ModelError, SchedulerPredictionType, MODEL_CLASSES, ModelError, SchedulerPredictionType, MODEL_CLASSES,
ModelConfigBase, ModelConfigBase, ModelNotFoundException,
) )
# We are only starting to number the config file with release 3. # We are only starting to number the config file with release 3.
@ -409,7 +409,7 @@ class ModelManager(object):
if model_key not in self.models: if model_key not in self.models:
self.scan_models_directory(base_model=base_model, model_type=model_type) self.scan_models_directory(base_model=base_model, model_type=model_type)
if model_key not in self.models: if model_key not in self.models:
raise Exception(f"Model not found - {model_key}") raise ModelNotFoundException(f"Model not found - {model_key}")
model_config = self.models[model_key] model_config = self.models[model_key]
model_path = self.app_config.root_path / model_config.path model_path = self.app_config.root_path / model_config.path
@ -421,7 +421,7 @@ class ModelManager(object):
else: else:
self.models.pop(model_key, None) self.models.pop(model_key, None)
raise Exception(f"Model not found - {model_key}") raise ModelNotFoundException(f"Model not found - {model_key}")
# vae/movq override # vae/movq override
# TODO: # TODO:
@ -798,12 +798,12 @@ class ModelManager(object):
if model_path.is_relative_to(self.app_config.root_path): if model_path.is_relative_to(self.app_config.root_path):
model_path = model_path.relative_to(self.app_config.root_path) model_path = model_path.relative_to(self.app_config.root_path)
try: try:
model_config: ModelConfigBase = model_class.probe_config(str(model_path)) model_config: ModelConfigBase = model_class.probe_config(str(model_path))
self.models[model_key] = model_config self.models[model_key] = model_config
new_models_found = True new_models_found = True
except NotImplementedError as e: except NotImplementedError as e:
self.logger.warning(e) self.logger.warning(e)
imported_models = self.autoimport() imported_models = self.autoimport()

View File

@ -2,7 +2,7 @@ import inspect
from enum import Enum from enum import Enum
from pydantic import BaseModel from pydantic import BaseModel
from typing import Literal, get_origin from typing import Literal, get_origin
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .vae import VaeModel from .vae import VaeModel
from .lora import LoRAModel from .lora import LoRAModel

View File

@ -15,6 +15,9 @@ from contextlib import suppress
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
class ModelNotFoundException(Exception):
pass
class BaseModelType(str, Enum): class BaseModelType(str, Enum):
StableDiffusion1 = "sd-1" StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2" StableDiffusion2 = "sd-2"

View File

@ -8,6 +8,7 @@ from .base import (
ModelType, ModelType,
SubModelType, SubModelType,
classproperty, classproperty,
ModelNotFoundException,
) )
# TODO: naming # TODO: naming
from ..lora import TextualInversionModel as TextualInversionModelRaw from ..lora import TextualInversionModel as TextualInversionModelRaw
@ -37,8 +38,15 @@ class TextualInversionModel(ModelBase):
if child_type is not None: if child_type is not None:
raise Exception("There is no child models in textual inversion") raise Exception("There is no child models in textual inversion")
checkpoint_path = self.model_path
if os.path.isdir(checkpoint_path):
checkpoint_path = os.path.join(checkpoint_path, "learned_embeds.bin")
if not os.path.exists(checkpoint_path):
raise ModelNotFoundException()
model = TextualInversionModelRaw.from_checkpoint( model = TextualInversionModelRaw.from_checkpoint(
file_path=self.model_path, file_path=checkpoint_path,
dtype=torch_dtype, dtype=torch_dtype,
) )

View File

@ -1,4 +1,8 @@
import { Box, ChakraProps, Flex, Heading, Image } from '@chakra-ui/react'; import { Box, ChakraProps, Flex, Heading, Image } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { memo } from 'react'; import { memo } from 'react';
import { TypesafeDraggableData } from './typesafeDnd'; import { TypesafeDraggableData } from './typesafeDnd';
@ -28,7 +32,24 @@ const STYLES: ChakraProps['sx'] = {
}, },
}; };
const selector = createSelector(
stateSelector,
(state) => {
const gallerySelectionCount = state.gallery.selection.length;
const batchSelectionCount = state.batch.selection.length;
return {
gallerySelectionCount,
batchSelectionCount,
};
},
defaultSelectorOptions
);
const DragPreview = (props: OverlayDragImageProps) => { const DragPreview = (props: OverlayDragImageProps) => {
const { gallerySelectionCount, batchSelectionCount } =
useAppSelector(selector);
if (!props.dragData) { if (!props.dragData) {
return; return;
} }
@ -57,7 +78,7 @@ const DragPreview = (props: OverlayDragImageProps) => {
); );
} }
if (props.dragData.payloadType === 'IMAGE_NAMES') { if (props.dragData.payloadType === 'BATCH_SELECTION') {
return ( return (
<Flex <Flex
sx={{ sx={{
@ -70,7 +91,26 @@ const DragPreview = (props: OverlayDragImageProps) => {
...STYLES, ...STYLES,
}} }}
> >
<Heading>{props.dragData.payload.imageNames.length}</Heading> <Heading>{batchSelectionCount}</Heading>
<Heading size="sm">Images</Heading>
</Flex>
);
}
if (props.dragData.payloadType === 'GALLERY_SELECTION') {
return (
<Flex
sx={{
cursor: 'none',
userSelect: 'none',
position: 'relative',
alignItems: 'center',
justifyContent: 'center',
flexDir: 'column',
...STYLES,
}}
>
<Heading>{gallerySelectionCount}</Heading>
<Heading size="sm">Images</Heading> <Heading size="sm">Images</Heading>
</Flex> </Flex>
); );

View File

@ -77,14 +77,18 @@ export type ImageDraggableData = BaseDragData & {
payload: { imageDTO: ImageDTO }; payload: { imageDTO: ImageDTO };
}; };
export type ImageNamesDraggableData = BaseDragData & { export type GallerySelectionDraggableData = BaseDragData & {
payloadType: 'IMAGE_NAMES'; payloadType: 'GALLERY_SELECTION';
payload: { imageNames: string[] }; };
export type BatchSelectionDraggableData = BaseDragData & {
payloadType: 'BATCH_SELECTION';
}; };
export type TypesafeDraggableData = export type TypesafeDraggableData =
| ImageDraggableData | ImageDraggableData
| ImageNamesDraggableData; | GallerySelectionDraggableData
| BatchSelectionDraggableData;
interface UseDroppableTypesafeArguments interface UseDroppableTypesafeArguments
extends Omit<UseDroppableArguments, 'data'> { extends Omit<UseDroppableArguments, 'data'> {
@ -155,11 +159,13 @@ export const isValidDrop = (
case 'SET_NODES_IMAGE': case 'SET_NODES_IMAGE':
return payloadType === 'IMAGE_DTO'; return payloadType === 'IMAGE_DTO';
case 'SET_MULTI_NODES_IMAGE': case 'SET_MULTI_NODES_IMAGE':
return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES'; return payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION';
case 'ADD_TO_BATCH': case 'ADD_TO_BATCH':
return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES'; return payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION';
case 'MOVE_BOARD': case 'MOVE_BOARD':
return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES'; return (
payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION' || 'BATCH_SELECTION'
);
default: default:
return false; return false;
} }

View File

@ -20,10 +20,8 @@ const serializationDenylist: {
nodes: nodesPersistDenylist, nodes: nodesPersistDenylist,
postprocessing: postprocessingPersistDenylist, postprocessing: postprocessingPersistDenylist,
system: systemPersistDenylist, system: systemPersistDenylist,
// config: configPersistDenyList,
ui: uiPersistDenylist, ui: uiPersistDenylist,
controlNet: controlNetDenylist, controlNet: controlNetDenylist,
// hotkeys: hotkeysPersistDenylist,
}; };
export const serialize: SerializeFunction = (data, key) => { export const serialize: SerializeFunction = (data, key) => {

View File

@ -1,21 +1,21 @@
import { startAppListening } from '..';
import { imageDeleted } from 'services/api/thunks/image';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { clamp } from 'lodash-es';
import {
imageSelected,
imageRemoved,
selectImagesIds,
} from 'features/gallery/store/gallerySlice';
import { resetCanvas } from 'features/canvas/store/canvasSlice'; import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { controlNetReset } from 'features/controlNet/store/controlNetSlice'; import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice'; import {
import { nodeEditorReset } from 'features/nodes/store/nodesSlice'; imageRemoved,
import { api } from 'services/api'; imageSelected,
selectFilteredImages,
} from 'features/gallery/store/gallerySlice';
import { import {
imageDeletionConfirmed, imageDeletionConfirmed,
isModalOpenChanged, isModalOpenChanged,
} from 'features/imageDeletion/store/imageDeletionSlice'; } from 'features/imageDeletion/store/imageDeletionSlice';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { clamp } from 'lodash-es';
import { api } from 'services/api';
import { imageDeleted } from 'services/api/thunks/image';
import { startAppListening } from '..';
const moduleLog = log.child({ namespace: 'image' }); const moduleLog = log.child({ namespace: 'image' });
@ -37,7 +37,9 @@ export const addRequestedImageDeletionListener = () => {
state.gallery.selection[state.gallery.selection.length - 1]; state.gallery.selection[state.gallery.selection.length - 1];
if (lastSelectedImage === image_name) { if (lastSelectedImage === image_name) {
const ids = selectImagesIds(state); const filteredImages = selectFilteredImages(state);
const ids = filteredImages.map((i) => i.image_name);
const deletedImageIndex = ids.findIndex( const deletedImageIndex = ids.findIndex(
(result) => result.toString() === image_name (result) => result.toString() === image_name

View File

@ -1,24 +1,23 @@
import { createAction } from '@reduxjs/toolkit'; import { createAction } from '@reduxjs/toolkit';
import { startAppListening } from '../';
import { log } from 'app/logging/useLogger';
import { import {
TypesafeDraggableData, TypesafeDraggableData,
TypesafeDroppableData, TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd'; } from 'app/components/ImageDnd/typesafeDnd';
import { imageSelected } from 'features/gallery/store/gallerySlice'; import { log } from 'app/logging/useLogger';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { import {
imageAddedToBatch, imageAddedToBatch,
imagesAddedToBatch, imagesAddedToBatch,
} from 'features/batch/store/batchSlice'; } from 'features/batch/store/batchSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { import {
fieldValueChanged, fieldValueChanged,
imageCollectionFieldValueChanged, imageCollectionFieldValueChanged,
} from 'features/nodes/store/nodesSlice'; } from 'features/nodes/store/nodesSlice';
import { boardsApi } from 'services/api/endpoints/boards'; import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { boardImagesApi } from 'services/api/endpoints/boardImages'; import { boardImagesApi } from 'services/api/endpoints/boardImages';
import { startAppListening } from '../';
const moduleLog = log.child({ namespace: 'dnd' }); const moduleLog = log.child({ namespace: 'dnd' });
@ -33,6 +32,7 @@ export const addImageDroppedListener = () => {
effect: (action, { dispatch, getState }) => { effect: (action, { dispatch, getState }) => {
const { activeData, overData } = action.payload; const { activeData, overData } = action.payload;
const { actionType } = overData; const { actionType } = overData;
const state = getState();
// set current image // set current image
if ( if (
@ -64,9 +64,9 @@ export const addImageDroppedListener = () => {
// add multiple images to batch // add multiple images to batch
if ( if (
actionType === 'ADD_TO_BATCH' && actionType === 'ADD_TO_BATCH' &&
activeData.payloadType === 'IMAGE_NAMES' activeData.payloadType === 'GALLERY_SELECTION'
) { ) {
dispatch(imagesAddedToBatch(activeData.payload.imageNames)); dispatch(imagesAddedToBatch(state.gallery.selection));
} }
// set control image // set control image
@ -128,14 +128,14 @@ export const addImageDroppedListener = () => {
// set multiple nodes images (multiple images handler) // set multiple nodes images (multiple images handler)
if ( if (
actionType === 'SET_MULTI_NODES_IMAGE' && actionType === 'SET_MULTI_NODES_IMAGE' &&
activeData.payloadType === 'IMAGE_NAMES' activeData.payloadType === 'GALLERY_SELECTION'
) { ) {
const { fieldName, nodeId } = overData.context; const { fieldName, nodeId } = overData.context;
dispatch( dispatch(
imageCollectionFieldValueChanged({ imageCollectionFieldValueChanged({
nodeId, nodeId,
fieldName, fieldName,
value: activeData.payload.imageNames.map((image_name) => ({ value: state.gallery.selection.map((image_name) => ({
image_name, image_name,
})), })),
}) })

View File

@ -8,31 +8,32 @@ import {
import dynamicMiddlewares from 'redux-dynamic-middlewares'; import dynamicMiddlewares from 'redux-dynamic-middlewares';
import { rememberEnhancer, rememberReducer } from 'redux-remember'; import { rememberEnhancer, rememberReducer } from 'redux-remember';
import batchReducer from 'features/batch/store/batchSlice';
import canvasReducer from 'features/canvas/store/canvasSlice'; import canvasReducer from 'features/canvas/store/canvasSlice';
import controlNetReducer from 'features/controlNet/store/controlNetSlice'; import controlNetReducer from 'features/controlNet/store/controlNetSlice';
import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice';
import boardsReducer from 'features/gallery/store/boardSlice';
import galleryReducer from 'features/gallery/store/gallerySlice'; import galleryReducer from 'features/gallery/store/gallerySlice';
import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice';
import lightboxReducer from 'features/lightbox/store/lightboxSlice'; import lightboxReducer from 'features/lightbox/store/lightboxSlice';
import loraReducer from 'features/lora/store/loraSlice';
import nodesReducer from 'features/nodes/store/nodesSlice';
import generationReducer from 'features/parameters/store/generationSlice'; import generationReducer from 'features/parameters/store/generationSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import systemReducer from 'features/system/store/systemSlice';
import nodesReducer from 'features/nodes/store/nodesSlice';
import boardsReducer from 'features/gallery/store/boardSlice';
import configReducer from 'features/system/store/configSlice'; import configReducer from 'features/system/store/configSlice';
import systemReducer from 'features/system/store/systemSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice'; import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import uiReducer from 'features/ui/store/uiSlice'; import uiReducer from 'features/ui/store/uiSlice';
import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice';
import batchReducer from 'features/batch/store/batchSlice';
import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice';
import { listenerMiddleware } from './middleware/listenerMiddleware'; import { listenerMiddleware } from './middleware/listenerMiddleware';
import { actionSanitizer } from './middleware/devtools/actionSanitizer'; import { api } from 'services/api';
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
import { LOCALSTORAGE_PREFIX } from './constants'; import { LOCALSTORAGE_PREFIX } from './constants';
import { serialize } from './enhancers/reduxRemember/serialize'; import { serialize } from './enhancers/reduxRemember/serialize';
import { unserialize } from './enhancers/reduxRemember/unserialize'; import { unserialize } from './enhancers/reduxRemember/unserialize';
import { api } from 'services/api'; import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
const allReducers = { const allReducers = {
canvas: canvasReducer, canvas: canvasReducer,
@ -50,6 +51,7 @@ const allReducers = {
dynamicPrompts: dynamicPromptsReducer, dynamicPrompts: dynamicPromptsReducer,
batch: batchReducer, batch: batchReducer,
imageDeletion: imageDeletionReducer, imageDeletion: imageDeletionReducer,
lora: loraReducer,
[api.reducerPath]: api.reducer, [api.reducerPath]: api.reducer,
}; };
@ -69,6 +71,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'controlNet', 'controlNet',
'dynamicPrompts', 'dynamicPrompts',
'batch', 'batch',
'lora',
// 'boards', // 'boards',
// 'hotkeys', // 'hotkeys',
// 'config', // 'config',

View File

@ -4,22 +4,25 @@ import {
Collapse, Collapse,
Flex, Flex,
Spacer, Spacer,
Switch, Text,
useColorMode, useColorMode,
useDisclosure,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { AnimatePresence, motion } from 'framer-motion';
import { PropsWithChildren, memo } from 'react'; import { PropsWithChildren, memo } from 'react';
import { mode } from 'theme/util/mode'; import { mode } from 'theme/util/mode';
export type IAIToggleCollapseProps = PropsWithChildren & { export type IAIToggleCollapseProps = PropsWithChildren & {
label: string; label: string;
isOpen: boolean; activeLabel?: string;
onToggle: () => void; defaultIsOpen?: boolean;
withSwitch?: boolean;
}; };
const IAICollapse = (props: IAIToggleCollapseProps) => { const IAICollapse = (props: IAIToggleCollapseProps) => {
const { label, isOpen, onToggle, children, withSwitch = false } = props; const { label, activeLabel, children, defaultIsOpen = false } = props;
const { isOpen, onToggle } = useDisclosure({ defaultIsOpen });
const { colorMode } = useColorMode(); const { colorMode } = useColorMode();
return ( return (
<Box> <Box>
<Flex <Flex
@ -28,6 +31,7 @@ const IAICollapse = (props: IAIToggleCollapseProps) => {
alignItems: 'center', alignItems: 'center',
p: 2, p: 2,
px: 4, px: 4,
gap: 2,
borderTopRadius: 'base', borderTopRadius: 'base',
borderBottomRadius: isOpen ? 0 : 'base', borderBottomRadius: isOpen ? 0 : 'base',
bg: isOpen bg: isOpen
@ -48,19 +52,40 @@ const IAICollapse = (props: IAIToggleCollapseProps) => {
}} }}
> >
{label} {label}
<AnimatePresence>
{activeLabel && (
<motion.div
key="statusText"
initial={{
opacity: 0,
}}
animate={{
opacity: 1,
transition: { duration: 0.1 },
}}
exit={{
opacity: 0,
transition: { duration: 0.1 },
}}
>
<Text
sx={{ color: 'accent.500', _dark: { color: 'accent.300' } }}
>
{activeLabel}
</Text>
</motion.div>
)}
</AnimatePresence>
<Spacer /> <Spacer />
{withSwitch && <Switch isChecked={isOpen} pointerEvents="none" />} <ChevronUpIcon
{!withSwitch && ( sx={{
<ChevronUpIcon w: '1rem',
sx={{ h: '1rem',
w: '1rem', transform: isOpen ? 'rotate(0deg)' : 'rotate(180deg)',
h: '1rem', transitionProperty: 'common',
transform: isOpen ? 'rotate(0deg)' : 'rotate(180deg)', transitionDuration: 'normal',
transitionProperty: 'common', }}
transitionDuration: 'normal', />
}}
/>
)}
</Flex> </Flex>
<Collapse in={isOpen} animateOpacity style={{ overflow: 'unset' }}> <Collapse in={isOpen} animateOpacity style={{ overflow: 'unset' }}>
<Box <Box

View File

@ -61,7 +61,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
'&:focus-within': { '&:focus-within': {
borderColor: mode(accent200, accent600)(colorMode), borderColor: mode(accent200, accent600)(colorMode),
}, },
'&:disabled': { '&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode), backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode), color: mode(base600, base400)(colorMode),
}, },

View File

@ -64,7 +64,7 @@ const IAIMantineSelect = (props: IAISelectProps) => {
'&:focus-within': { '&:focus-within': {
borderColor: mode(accent200, accent600)(colorMode), borderColor: mode(accent200, accent600)(colorMode),
}, },
'&:disabled': { '&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode), backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode), color: mode(base600, base400)(colorMode),
}, },

View File

@ -36,7 +36,6 @@ const IAISwitch = (props: Props) => {
isDisabled={isDisabled} isDisabled={isDisabled}
width={width} width={width}
display="flex" display="flex"
gap={4}
alignItems="center" alignItems="center"
{...formControlProps} {...formControlProps}
> >
@ -47,6 +46,7 @@ const IAISwitch = (props: Props) => {
sx={{ sx={{
cursor: isDisabled ? 'not-allowed' : 'pointer', cursor: isDisabled ? 'not-allowed' : 'pointer',
...formLabelProps?.sx, ...formLabelProps?.sx,
pe: 4,
}} }}
{...formLabelProps} {...formLabelProps}
> >

View File

@ -1,28 +1,29 @@
import { Box, Icon, Skeleton } from '@chakra-ui/react'; import { Box, Icon, Skeleton } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { FaExclamationCircle } from 'react-icons/fa'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import IAIDndImage from 'common/components/IAIDndImage';
import { MouseEvent, memo, useCallback, useMemo } from 'react';
import { import {
batchImageRangeEndSelected, batchImageRangeEndSelected,
batchImageSelected, batchImageSelected,
batchImageSelectionToggled, batchImageSelectionToggled,
imageRemovedFromBatch, imageRemovedFromBatch,
} from 'features/batch/store/batchSlice'; } from 'features/batch/store/batchSlice';
import IAIDndImage from 'common/components/IAIDndImage'; import { MouseEvent, memo, useCallback, useMemo } from 'react';
import { createSelector } from '@reduxjs/toolkit'; import { FaExclamationCircle } from 'react-icons/fa';
import { RootState, stateSelector } from 'app/store/store'; import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd';
const isSelectedSelector = createSelector( const makeSelector = (image_name: string) =>
[stateSelector, (state: RootState, imageName: string) => imageName], createSelector(
(state, imageName) => ({ [stateSelector],
selection: state.batch.selection, (state) => ({
isSelected: state.batch.selection.includes(imageName), selectionCount: state.batch.selection.length,
}), isSelected: state.batch.selection.includes(image_name),
defaultSelectorOptions }),
); defaultSelectorOptions
);
type BatchImageProps = { type BatchImageProps = {
imageName: string; imageName: string;
@ -37,10 +38,13 @@ const BatchImage = (props: BatchImageProps) => {
} = useGetImageDTOQuery(props.imageName); } = useGetImageDTOQuery(props.imageName);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { isSelected, selection } = useAppSelector((state) => const selector = useMemo(
isSelectedSelector(state, props.imageName) () => makeSelector(props.imageName),
[props.imageName]
); );
const { isSelected, selectionCount } = useAppSelector(selector);
const handleClickRemove = useCallback(() => { const handleClickRemove = useCallback(() => {
dispatch(imageRemovedFromBatch(props.imageName)); dispatch(imageRemovedFromBatch(props.imageName));
}, [dispatch, props.imageName]); }, [dispatch, props.imageName]);
@ -59,13 +63,10 @@ const BatchImage = (props: BatchImageProps) => {
); );
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => { const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
if (selection.length > 1) { if (selectionCount > 1) {
return { return {
id: 'batch', id: 'batch',
payloadType: 'IMAGE_NAMES', payloadType: 'BATCH_SELECTION',
payload: {
imageNames: selection,
},
}; };
} }
@ -76,7 +77,7 @@ const BatchImage = (props: BatchImageProps) => {
payload: { imageDTO }, payload: { imageDTO },
}; };
} }
}, [imageDTO, selection]); }, [imageDTO, selectionCount]);
if (isError) { if (isError) {
return <Icon as={FaExclamationCircle} />; return <Icon as={FaExclamationCircle} />;

View File

@ -1,25 +1,22 @@
import { memo, useCallback, useMemo, useState } from 'react';
import { ImageDTO } from 'services/api/types';
import {
ControlNetConfig,
controlNetImageChanged,
controlNetSelector,
} from '../store/controlNetSlice';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { Box, Flex, SystemStyleObject } from '@chakra-ui/react'; import { Box, Flex, SystemStyleObject } from '@chakra-ui/react';
import IAIDndImage from 'common/components/IAIDndImage';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { IAILoadingImageFallback } from 'common/components/IAIImageFallback';
import IAIIconButton from 'common/components/IAIIconButton';
import { FaUndo } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { skipToken } from '@reduxjs/toolkit/dist/query'; import { skipToken } from '@reduxjs/toolkit/dist/query';
import { import {
TypesafeDraggableData, TypesafeDraggableData,
TypesafeDroppableData, TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd'; } from 'app/components/ImageDnd/typesafeDnd';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage';
import { IAILoadingImageFallback } from 'common/components/IAIImageFallback';
import { memo, useCallback, useMemo, useState } from 'react';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/thunks/image'; import { PostUploadAction } from 'services/api/thunks/image';
import {
ControlNetConfig,
controlNetImageChanged,
controlNetSelector,
} from '../store/controlNetSlice';
const selector = createSelector( const selector = createSelector(
controlNetSelector, controlNetSelector,
@ -83,15 +80,14 @@ const ControlNetImagePreview = (props: Props) => {
} }
}, [controlImage, controlNetId]); }, [controlImage, controlNetId]);
const droppableData = useMemo<TypesafeDroppableData | undefined>(() => { const droppableData = useMemo<TypesafeDroppableData | undefined>(
if (controlNetId) { () => ({
return { id: controlNetId,
id: controlNetId, actionType: 'SET_CONTROLNET_IMAGE',
actionType: 'SET_CONTROLNET_IMAGE', context: { controlNetId },
context: { controlNetId }, }),
}; [controlNetId]
} );
}, [controlNetId]);
const postUploadAction = useMemo<PostUploadAction>( const postUploadAction = useMemo<PostUploadAction>(
() => ({ type: 'SET_CONTROLNET_IMAGE', controlNetId }), () => ({ type: 'SET_CONTROLNET_IMAGE', controlNetId }),

View File

@ -0,0 +1,36 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISwitch from 'common/components/IAISwitch';
import { isControlNetEnabledToggled } from 'features/controlNet/store/controlNetSlice';
import { useCallback } from 'react';
const selector = createSelector(
stateSelector,
(state) => {
const { isEnabled } = state.controlNet;
return { isEnabled };
},
defaultSelectorOptions
);
const ParamControlNetFeatureToggle = () => {
const { isEnabled } = useAppSelector(selector);
const dispatch = useAppDispatch();
const handleChange = useCallback(() => {
dispatch(isControlNetEnabledToggled());
}, [dispatch]);
return (
<IAISwitch
label="Enable ControlNet"
isChecked={isEnabled}
onChange={handleChange}
/>
);
};
export default ParamControlNetFeatureToggle;

View File

@ -0,0 +1,15 @@
import { filter } from 'lodash-es';
import { ControlNetConfig } from '../store/controlNetSlice';
export const getValidControlNets = (
controlNets: Record<string, ControlNetConfig>
) => {
const validControlNets = filter(
controlNets,
(c) =>
c.isEnabled &&
(Boolean(c.processedControlImage) ||
(c.processorType === 'none' && Boolean(c.controlImage)))
);
return validControlNets;
};

View File

@ -1,40 +1,30 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse'; import IAICollapse from 'common/components/IAICollapse';
import { useCallback } from 'react';
import { isEnabledToggled } from '../store/slice';
import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts';
import ParamDynamicPromptsCombinatorial from './ParamDynamicPromptsCombinatorial'; import ParamDynamicPromptsCombinatorial from './ParamDynamicPromptsCombinatorial';
import { Flex } from '@chakra-ui/react'; import ParamDynamicPromptsToggle from './ParamDynamicPromptsEnabled';
import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts';
const selector = createSelector( const selector = createSelector(
stateSelector, stateSelector,
(state) => { (state) => {
const { isEnabled } = state.dynamicPrompts; const { isEnabled } = state.dynamicPrompts;
return { isEnabled }; return { activeLabel: isEnabled ? 'Enabled' : undefined };
}, },
defaultSelectorOptions defaultSelectorOptions
); );
const ParamDynamicPromptsCollapse = () => { const ParamDynamicPromptsCollapse = () => {
const dispatch = useAppDispatch(); const { activeLabel } = useAppSelector(selector);
const { isEnabled } = useAppSelector(selector);
const handleToggleIsEnabled = useCallback(() => {
dispatch(isEnabledToggled());
}, [dispatch]);
return ( return (
<IAICollapse <IAICollapse label="Dynamic Prompts" activeLabel={activeLabel}>
isOpen={isEnabled}
onToggle={handleToggleIsEnabled}
label="Dynamic Prompts"
withSwitch
>
<Flex sx={{ gap: 2, flexDir: 'column' }}> <Flex sx={{ gap: 2, flexDir: 'column' }}>
<ParamDynamicPromptsToggle />
<ParamDynamicPromptsCombinatorial /> <ParamDynamicPromptsCombinatorial />
<ParamDynamicPromptsMaxPrompts /> <ParamDynamicPromptsMaxPrompts />
</Flex> </Flex>

View File

@ -1,23 +1,23 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { combinatorialToggled } from '../store/slice';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useCallback } from 'react';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { useCallback } from 'react';
import { combinatorialToggled } from '../store/slice';
const selector = createSelector( const selector = createSelector(
stateSelector, stateSelector,
(state) => { (state) => {
const { combinatorial } = state.dynamicPrompts; const { combinatorial, isEnabled } = state.dynamicPrompts;
return { combinatorial }; return { combinatorial, isDisabled: !isEnabled };
}, },
defaultSelectorOptions defaultSelectorOptions
); );
const ParamDynamicPromptsCombinatorial = () => { const ParamDynamicPromptsCombinatorial = () => {
const { combinatorial } = useAppSelector(selector); const { combinatorial, isDisabled } = useAppSelector(selector);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const handleChange = useCallback(() => { const handleChange = useCallback(() => {
@ -26,6 +26,7 @@ const ParamDynamicPromptsCombinatorial = () => {
return ( return (
<IAISwitch <IAISwitch
isDisabled={isDisabled}
label="Combinatorial Generation" label="Combinatorial Generation"
isChecked={combinatorial} isChecked={combinatorial}
onChange={handleChange} onChange={handleChange}

View File

@ -0,0 +1,36 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISwitch from 'common/components/IAISwitch';
import { useCallback } from 'react';
import { isEnabledToggled } from '../store/slice';
const selector = createSelector(
stateSelector,
(state) => {
const { isEnabled } = state.dynamicPrompts;
return { isEnabled };
},
defaultSelectorOptions
);
const ParamDynamicPromptsToggle = () => {
const dispatch = useAppDispatch();
const { isEnabled } = useAppSelector(selector);
const handleToggleIsEnabled = useCallback(() => {
dispatch(isEnabledToggled());
}, [dispatch]);
return (
<IAISwitch
label="Enable Dynamic Prompts"
isChecked={isEnabled}
onChange={handleToggleIsEnabled}
/>
);
};
export default ParamDynamicPromptsToggle;

View File

@ -1,25 +1,31 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { maxPromptsChanged, maxPromptsReset } from '../store/slice';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useCallback } from 'react';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider';
import { useCallback } from 'react';
import { maxPromptsChanged, maxPromptsReset } from '../store/slice';
const selector = createSelector( const selector = createSelector(
stateSelector, stateSelector,
(state) => { (state) => {
const { maxPrompts, combinatorial } = state.dynamicPrompts; const { maxPrompts, combinatorial, isEnabled } = state.dynamicPrompts;
const { min, sliderMax, inputMax } = const { min, sliderMax, inputMax } =
state.config.sd.dynamicPrompts.maxPrompts; state.config.sd.dynamicPrompts.maxPrompts;
return { maxPrompts, min, sliderMax, inputMax, combinatorial }; return {
maxPrompts,
min,
sliderMax,
inputMax,
isDisabled: !isEnabled || !combinatorial,
};
}, },
defaultSelectorOptions defaultSelectorOptions
); );
const ParamDynamicPromptsMaxPrompts = () => { const ParamDynamicPromptsMaxPrompts = () => {
const { maxPrompts, min, sliderMax, inputMax, combinatorial } = const { maxPrompts, min, sliderMax, inputMax, isDisabled } =
useAppSelector(selector); useAppSelector(selector);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
@ -37,7 +43,7 @@ const ParamDynamicPromptsMaxPrompts = () => {
return ( return (
<IAISlider <IAISlider
label="Max Prompts" label="Max Prompts"
isDisabled={!combinatorial} isDisabled={isDisabled}
min={min} min={min}
max={sliderMax} max={sliderMax}
value={maxPrompts} value={maxPrompts}

View File

@ -1,19 +1,19 @@
import { Box, Flex, Image } from '@chakra-ui/react'; import { Box, Flex, Image } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { isEqual } from 'lodash-es';
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
import NextPrevImageButtons from './NextPrevImageButtons';
import { memo, useMemo } from 'react';
import IAIDndImage from 'common/components/IAIDndImage';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { skipToken } from '@reduxjs/toolkit/dist/query'; import { skipToken } from '@reduxjs/toolkit/dist/query';
import { stateSelector } from 'app/store/store';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySlice';
import { import {
TypesafeDraggableData, TypesafeDraggableData,
TypesafeDroppableData, TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd'; } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySlice';
import { isEqual } from 'lodash-es';
import { memo, useMemo } from 'react';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
import NextPrevImageButtons from './NextPrevImageButtons';
export const imagesSelector = createSelector( export const imagesSelector = createSelector(
[stateSelector, selectLastSelectedImage], [stateSelector, selectLastSelectedImage],

View File

@ -1,34 +1,35 @@
import { Box } from '@chakra-ui/react'; import { Box } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { MouseEvent, memo, useCallback, useMemo } from 'react';
import { FaTrash } from 'react-icons/fa';
import { useTranslation } from 'react-i18next';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { ImageDTO } from 'services/api/types';
import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd'; import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import ImageContextMenu from './ImageContextMenu'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage'; import IAIDndImage from 'common/components/IAIDndImage';
import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice';
import { MouseEvent, memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { FaTrash } from 'react-icons/fa';
import { ImageDTO } from 'services/api/types';
import { import {
imageRangeEndSelected, imageRangeEndSelected,
imageSelected, imageSelected,
imageSelectionToggled, imageSelectionToggled,
} from '../store/gallerySlice'; } from '../store/gallerySlice';
import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice'; import ImageContextMenu from './ImageContextMenu';
export const selector = createSelector( export const makeSelector = (image_name: string) =>
[stateSelector, (state, { image_name }: ImageDTO) => image_name], createSelector(
({ gallery }, image_name) => { [stateSelector],
const isSelected = gallery.selection.includes(image_name); ({ gallery }) => {
const selection = gallery.selection; const isSelected = gallery.selection.includes(image_name);
return { const selectionCount = gallery.selection.length;
isSelected, return {
selection, isSelected,
}; selectionCount,
}, };
defaultSelectorOptions },
); defaultSelectorOptions
);
interface HoverableImageProps { interface HoverableImageProps {
imageDTO: ImageDTO; imageDTO: ImageDTO;
@ -38,13 +39,13 @@ interface HoverableImageProps {
* Gallery image component with delete/use all/use seed buttons on hover. * Gallery image component with delete/use all/use seed buttons on hover.
*/ */
const GalleryImage = (props: HoverableImageProps) => { const GalleryImage = (props: HoverableImageProps) => {
const { isSelected, selection } = useAppSelector((state) =>
selector(state, props.imageDTO)
);
const { imageDTO } = props; const { imageDTO } = props;
const { image_url, thumbnail_url, image_name } = imageDTO; const { image_url, thumbnail_url, image_name } = imageDTO;
const localSelector = useMemo(() => makeSelector(image_name), [image_name]);
const { isSelected, selectionCount } = useAppSelector(localSelector);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
@ -74,11 +75,10 @@ const GalleryImage = (props: HoverableImageProps) => {
); );
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => { const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
if (selection.length > 1) { if (selectionCount > 1) {
return { return {
id: 'gallery-image', id: 'gallery-image',
payloadType: 'IMAGE_NAMES', payloadType: 'GALLERY_SELECTION',
payload: { imageNames: selection },
}; };
} }
@ -89,7 +89,7 @@ const GalleryImage = (props: HoverableImageProps) => {
payload: { imageDTO }, payload: { imageDTO },
}; };
} }
}, [imageDTO, selection]); }, [imageDTO, selectionCount]);
return ( return (
<Box sx={{ w: 'full', h: 'full', touchAction: 'none' }}> <Box sx={{ w: 'full', h: 'full', touchAction: 'none' }}>

View File

@ -7,7 +7,6 @@ import {
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { dateComparator } from 'common/util/dateComparator'; import { dateComparator } from 'common/util/dateComparator';
import { imageDeletionConfirmed } from 'features/imageDeletion/store/imageDeletionSlice';
import { keyBy, uniq } from 'lodash-es'; import { keyBy, uniq } from 'lodash-es';
import { boardsApi } from 'services/api/endpoints/boards'; import { boardsApi } from 'services/api/endpoints/boards';
import { import {
@ -174,11 +173,6 @@ export const gallerySlice = createSlice({
state.limit = limit; state.limit = limit;
state.total = total; state.total = total;
}); });
builder.addCase(imageDeletionConfirmed, (state, action) => {
// Image deleted
const { image_name } = action.payload.imageDTO;
imagesAdapter.removeOne(state, image_name);
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_url, thumbnail_url } = action.payload; const { image_name, image_url, thumbnail_url } = action.payload;

View File

@ -23,6 +23,7 @@ import { stateSelector } from 'app/store/store';
import { import {
imageDeletionConfirmed, imageDeletionConfirmed,
imageToDeleteCleared, imageToDeleteCleared,
isModalOpenChanged,
selectImageUsage, selectImageUsage,
} from '../store/imageDeletionSlice'; } from '../store/imageDeletionSlice';
@ -63,6 +64,7 @@ const DeleteImageModal = () => {
const handleClose = useCallback(() => { const handleClose = useCallback(() => {
dispatch(imageToDeleteCleared()); dispatch(imageToDeleteCleared());
dispatch(isModalOpenChanged(false));
}, [dispatch]); }, [dispatch]);
const handleDelete = useCallback(() => { const handleDelete = useCallback(() => {

View File

@ -31,6 +31,7 @@ const imageDeletion = createSlice({
}, },
imageToDeleteCleared: (state) => { imageToDeleteCleared: (state) => {
state.imageToDelete = null; state.imageToDelete = null;
state.isModalOpen = false;
}, },
}, },
}); });

View File

@ -0,0 +1,59 @@
import { Flex } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import IAISlider from 'common/components/IAISlider';
import { memo, useCallback } from 'react';
import { FaTrash } from 'react-icons/fa';
import { Lora, loraRemoved, loraWeightChanged } from '../store/loraSlice';
type Props = {
lora: Lora;
};
const ParamLora = (props: Props) => {
const dispatch = useAppDispatch();
const { lora } = props;
const handleChange = useCallback(
(v: number) => {
dispatch(loraWeightChanged({ id: lora.id, weight: v }));
},
[dispatch, lora.id]
);
const handleReset = useCallback(() => {
dispatch(loraWeightChanged({ id: lora.id, weight: 1 }));
}, [dispatch, lora.id]);
const handleRemoveLora = useCallback(() => {
dispatch(loraRemoved(lora.id));
}, [dispatch, lora.id]);
return (
<Flex sx={{ gap: 2.5, alignItems: 'flex-end' }}>
<IAISlider
label={lora.name}
value={lora.weight}
onChange={handleChange}
min={-1}
max={2}
step={0.01}
withInput
withReset
handleReset={handleReset}
withSliderMarks
sliderMarks={[-1, 0, 1, 2]}
/>
<IAIIconButton
size="sm"
onClick={handleRemoveLora}
tooltip="Remove LoRA"
aria-label="Remove LoRA"
icon={<FaTrash />}
colorScheme="error"
/>
</Flex>
);
};
export default memo(ParamLora);

View File

@ -0,0 +1,36 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import { size } from 'lodash-es';
import { memo } from 'react';
import ParamLoraList from './ParamLoraList';
import ParamLoraSelect from './ParamLoraSelect';
const selector = createSelector(
stateSelector,
(state) => {
const loraCount = size(state.lora.loras);
return {
activeLabel: loraCount > 0 ? `${loraCount} Active` : undefined,
};
},
defaultSelectorOptions
);
const ParamLoraCollapse = () => {
const { activeLabel } = useAppSelector(selector);
return (
<IAICollapse label={'LoRA'} activeLabel={activeLabel}>
<Flex sx={{ flexDir: 'column', gap: 2 }}>
<ParamLoraSelect />
<ParamLoraList />
</Flex>
</IAICollapse>
);
};
export default memo(ParamLoraCollapse);

View File

@ -0,0 +1,24 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { map } from 'lodash-es';
import ParamLora from './ParamLora';
const selector = createSelector(
stateSelector,
({ lora }) => {
const { loras } = lora;
return { loras };
},
defaultSelectorOptions
);
const ParamLoraList = () => {
const { loras } = useAppSelector(selector);
return map(loras, (lora) => <ParamLora key={lora.name} lora={lora} />);
};
export default ParamLoraList;

View File

@ -0,0 +1,107 @@
import { Text } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
import { forEach } from 'lodash-es';
import { forwardRef, useCallback, useMemo } from 'react';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
import { loraAdded } from '../store/loraSlice';
type LoraSelectItem = {
label: string;
value: string;
description?: string;
};
const selector = createSelector(
stateSelector,
({ lora }) => ({
loras: lora.loras,
}),
defaultSelectorOptions
);
const ParamLoraSelect = () => {
const dispatch = useAppDispatch();
const { loras } = useAppSelector(selector);
const { data: lorasQueryData } = useGetLoRAModelsQuery();
const data = useMemo(() => {
if (!lorasQueryData) {
return [];
}
const data: LoraSelectItem[] = [];
forEach(lorasQueryData.entities, (lora, id) => {
if (!lora || Boolean(id in loras)) {
return;
}
data.push({
value: id,
label: lora.name,
description: lora.description,
});
});
return data;
}, [loras, lorasQueryData]);
const handleChange = useCallback(
(v: string[]) => {
const loraEntity = lorasQueryData?.entities[v[0]];
if (!loraEntity) {
return;
}
v[0] && dispatch(loraAdded(loraEntity));
},
[dispatch, lorasQueryData?.entities]
);
return (
<IAIMantineMultiSelect
placeholder={data.length === 0 ? 'All LoRAs added' : 'Add LoRA'}
value={[]}
data={data}
maxDropdownHeight={400}
nothingFound="No matching LoRAs"
itemComponent={SelectItem}
disabled={data.length === 0}
filter={(value, selected, item: LoraSelectItem) =>
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim())
}
onChange={handleChange}
/>
);
};
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
value: string;
label: string;
description?: string;
}
const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
({ label, description, ...others }: ItemProps, ref) => {
return (
<div ref={ref} {...others}>
<div>
<Text>{label}</Text>
{description && (
<Text size="xs" color="base.600">
{description}
</Text>
)}
</div>
</div>
);
}
);
SelectItem.displayName = 'SelectItem';
export default ParamLoraSelect;

View File

@ -0,0 +1,46 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { LoRAModelConfigEntity } from 'services/api/endpoints/models';
export type Lora = {
id: string;
name: string;
weight: number;
};
export const defaultLoRAConfig: Omit<Lora, 'id' | 'name'> = {
weight: 1,
};
export type LoraState = {
loras: Record<string, Lora>;
};
export const intialLoraState: LoraState = {
loras: {},
};
export const loraSlice = createSlice({
name: 'lora',
initialState: intialLoraState,
reducers: {
loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => {
const { name, id } = action.payload;
state.loras[id] = { id, name, ...defaultLoRAConfig };
},
loraRemoved: (state, action: PayloadAction<string>) => {
const id = action.payload;
delete state.loras[id];
},
loraWeightChanged: (
state,
action: PayloadAction<{ id: string; weight: number }>
) => {
const { id, weight } = action.payload;
state.loras[id].weight = weight;
},
},
});
export const { loraAdded, loraRemoved, loraWeightChanged } = loraSlice.actions;
export default loraSlice.reducer;

View File

@ -12,6 +12,7 @@ import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFie
import ImageInputFieldComponent from './fields/ImageInputFieldComponent'; import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
import ItemInputFieldComponent from './fields/ItemInputFieldComponent'; import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent'; import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
import LoRAModelInputFieldComponent from './fields/LoRAModelInputFieldComponent';
import ModelInputFieldComponent from './fields/ModelInputFieldComponent'; import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
import NumberInputFieldComponent from './fields/NumberInputFieldComponent'; import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
import StringInputFieldComponent from './fields/StringInputFieldComponent'; import StringInputFieldComponent from './fields/StringInputFieldComponent';
@ -163,6 +164,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
); );
} }
if (type === 'lora_model' && template.type === 'lora_model') {
return (
<LoRAModelInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'array' && template.type === 'array') { if (type === 'array' && template.type === 'array') {
return ( return (
<ArrayInputFieldComponent <ArrayInputFieldComponent

View File

@ -7,18 +7,16 @@ import {
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { FieldComponentProps } from './types';
import IAIDndImage from 'common/components/IAIDndImage';
import { ImageDTO } from 'services/api/types';
import { Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { skipToken } from '@reduxjs/toolkit/dist/query'; import { skipToken } from '@reduxjs/toolkit/dist/query';
import { import {
NodesImageDropData,
TypesafeDraggableData, TypesafeDraggableData,
TypesafeDroppableData, TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd'; } from 'app/components/ImageDnd/typesafeDnd';
import IAIDndImage from 'common/components/IAIDndImage';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/thunks/image'; import { PostUploadAction } from 'services/api/thunks/image';
import { FieldComponentProps } from './types';
const ImageInputFieldComponent = ( const ImageInputFieldComponent = (
props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate> props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate>
@ -34,23 +32,6 @@ const ImageInputFieldComponent = (
isSuccess, isSuccess,
} = useGetImageDTOQuery(field.value?.image_name ?? skipToken); } = useGetImageDTOQuery(field.value?.image_name ?? skipToken);
const handleDrop = useCallback(
({ image_name }: ImageDTO) => {
if (field.value?.image_name === image_name) {
return;
}
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: { image_name },
})
);
},
[dispatch, field.name, field.value, nodeId]
);
const handleReset = useCallback(() => { const handleReset = useCallback(() => {
dispatch( dispatch(
fieldValueChanged({ fieldValueChanged({
@ -71,15 +52,14 @@ const ImageInputFieldComponent = (
} }
}, [field.name, imageDTO, nodeId]); }, [field.name, imageDTO, nodeId]);
const droppableData = useMemo<TypesafeDroppableData | undefined>(() => { const droppableData = useMemo<TypesafeDroppableData | undefined>(
if (imageDTO) { () => ({
return { id: `node-${nodeId}-${field.name}`,
id: `node-${nodeId}-${field.name}`, actionType: 'SET_NODES_IMAGE',
actionType: 'SET_NODES_IMAGE', context: { nodeId, fieldName: field.name },
context: { nodeId, fieldName: field.name }, }),
}; [field.name, nodeId]
} );
}, [field.name, imageDTO, nodeId]);
const postUploadAction = useMemo<PostUploadAction>( const postUploadAction = useMemo<PostUploadAction>(
() => ({ () => ({

View File

@ -0,0 +1,102 @@
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
VaeModelInputFieldTemplate,
VaeModelInputFieldValue,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
import { forEach, isString } from 'lodash-es';
import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
const LoRAModelInputFieldComponent = (
props: FieldComponentProps<
VaeModelInputFieldValue,
VaeModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { data: loraModels } = useGetLoRAModelsQuery();
const selectedModel = useMemo(
() => loraModels?.entities[field.value ?? loraModels.ids[0]],
[loraModels?.entities, loraModels?.ids, field.value]
);
const data = useMemo(() => {
if (!loraModels) {
return [];
}
const data: SelectItem[] = [];
forEach(loraModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.name,
group: BASE_MODEL_NAME_MAP[model.base_model],
});
});
return data;
}, [loraModels]);
const handleValueChanged = useCallback(
(v: string | null) => {
if (!v) {
return;
}
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: v,
})
);
},
[dispatch, field.name, nodeId]
);
useEffect(() => {
if (field.value && loraModels?.ids.includes(field.value)) {
return;
}
const firstLora = loraModels?.ids[0];
if (!isString(firstLora)) {
return;
}
handleValueChanged(firstLora);
}, [field.value, handleValueChanged, loraModels?.ids]);
return (
<IAIMantineSelect
tooltip={selectedModel?.description}
label={
selectedModel?.base_model &&
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
}
value={field.value}
placeholder="Pick one"
data={data}
onChange={handleValueChanged}
/>
);
};
export default memo(LoRAModelInputFieldComponent);

View File

@ -11,7 +11,7 @@ import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/component
import { forEach, isString } from 'lodash-es'; import { forEach, isString } from 'lodash-es';
import { memo, useCallback, useEffect, useMemo } from 'react'; import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useListModelsQuery } from 'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types'; import { FieldComponentProps } from './types';
const ModelInputFieldComponent = ( const ModelInputFieldComponent = (
@ -22,9 +22,7 @@ const ModelInputFieldComponent = (
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const { data: mainModels } = useListModelsQuery({ const { data: mainModels } = useGetMainModelsQuery();
model_type: 'main',
});
const data = useMemo(() => { const data = useMemo(() => {
if (!mainModels) { if (!mainModels) {

View File

@ -10,7 +10,7 @@ import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/component
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
import { memo, useCallback, useEffect, useMemo } from 'react'; import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useListModelsQuery } from 'services/api/endpoints/models'; import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types'; import { FieldComponentProps } from './types';
const VaeModelInputFieldComponent = ( const VaeModelInputFieldComponent = (
@ -24,9 +24,7 @@ const VaeModelInputFieldComponent = (
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const { data: vaeModels } = useListModelsQuery({ const { data: vaeModels } = useGetVaeModelsQuery();
model_type: 'vae',
});
const selectedModel = useMemo( const selectedModel = useMemo(
() => vaeModels?.entities[field.value ?? vaeModels.ids[0]], () => vaeModels?.entities[field.value ?? vaeModels.ids[0]],

View File

@ -1,5 +1,8 @@
import { createSlice, PayloadAction } from '@reduxjs/toolkit'; import { createSlice, PayloadAction } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { cloneDeep, uniqBy } from 'lodash-es';
import { OpenAPIV3 } from 'openapi-types'; import { OpenAPIV3 } from 'openapi-types';
import { RgbaColor } from 'react-colorful';
import { import {
addEdge, addEdge,
applyEdgeChanges, applyEdgeChanges,
@ -11,12 +14,9 @@ import {
NodeChange, NodeChange,
OnConnectStartParams, OnConnectStartParams,
} from 'reactflow'; } from 'reactflow';
import { ImageField } from 'services/api/types';
import { receivedOpenAPISchema } from 'services/api/thunks/schema'; import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { ImageField } from 'services/api/types';
import { InvocationTemplate, InvocationValue } from '../types/types'; import { InvocationTemplate, InvocationValue } from '../types/types';
import { RgbaColor } from 'react-colorful';
import { RootState } from 'app/store/store';
import { cloneDeep, isArray, uniq, uniqBy } from 'lodash-es';
export type NodesState = { export type NodesState = {
nodes: Node<InvocationValue>[]; nodes: Node<InvocationValue>[];

View File

@ -18,6 +18,7 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
VaeField: 'vae', VaeField: 'vae',
model: 'model', model: 'model',
vae_model: 'vae_model', vae_model: 'vae_model',
lora_model: 'lora_model',
array: 'array', array: 'array',
item: 'item', item: 'item',
ColorField: 'color', ColorField: 'color',
@ -120,7 +121,13 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
vae_model: { vae_model: {
color: 'teal', color: 'teal',
colorCssVar: getColorTokenCssVariable('teal'), colorCssVar: getColorTokenCssVariable('teal'),
title: 'Model', title: 'VAE',
description: 'Models are models.',
},
lora_model: {
color: 'teal',
colorCssVar: getColorTokenCssVariable('teal'),
title: 'LoRA',
description: 'Models are models.', description: 'Models are models.',
}, },
array: { array: {

View File

@ -65,6 +65,7 @@ export type FieldType =
| 'control' | 'control'
| 'model' | 'model'
| 'vae_model' | 'vae_model'
| 'lora_model'
| 'array' | 'array'
| 'item' | 'item'
| 'color' | 'color'
@ -93,6 +94,7 @@ export type InputFieldValue =
| EnumInputFieldValue | EnumInputFieldValue
| ModelInputFieldValue | ModelInputFieldValue
| VaeModelInputFieldValue | VaeModelInputFieldValue
| LoRAModelInputFieldValue
| ArrayInputFieldValue | ArrayInputFieldValue
| ItemInputFieldValue | ItemInputFieldValue
| ColorInputFieldValue | ColorInputFieldValue
@ -119,6 +121,7 @@ export type InputFieldTemplate =
| EnumInputFieldTemplate | EnumInputFieldTemplate
| ModelInputFieldTemplate | ModelInputFieldTemplate
| VaeModelInputFieldTemplate | VaeModelInputFieldTemplate
| LoRAModelInputFieldTemplate
| ArrayInputFieldTemplate | ArrayInputFieldTemplate
| ItemInputFieldTemplate | ItemInputFieldTemplate
| ColorInputFieldTemplate | ColorInputFieldTemplate
@ -236,6 +239,11 @@ export type VaeModelInputFieldValue = FieldValueBase & {
value?: string; value?: string;
}; };
export type LoRAModelInputFieldValue = FieldValueBase & {
type: 'lora_model';
value?: string;
};
export type ArrayInputFieldValue = FieldValueBase & { export type ArrayInputFieldValue = FieldValueBase & {
type: 'array'; type: 'array';
value?: (string | number)[]; value?: (string | number)[];
@ -350,6 +358,11 @@ export type VaeModelInputFieldTemplate = InputFieldTemplateBase & {
type: 'vae_model'; type: 'vae_model';
}; };
export type LoRAModelInputFieldTemplate = InputFieldTemplateBase & {
default: string;
type: 'lora_model';
};
export type ArrayInputFieldTemplate = InputFieldTemplateBase & { export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
default: []; default: [];
type: 'array'; type: 'array';

View File

@ -1,5 +1,5 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { filter } from 'lodash-es'; import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
import { CollectInvocation, ControlNetInvocation } from 'services/api/types'; import { CollectInvocation, ControlNetInvocation } from 'services/api/types';
import { NonNullableGraph } from '../types/types'; import { NonNullableGraph } from '../types/types';
import { CONTROL_NET_COLLECT } from './graphBuilders/constants'; import { CONTROL_NET_COLLECT } from './graphBuilders/constants';
@ -11,13 +11,7 @@ export const addControlNetToLinearGraph = (
): void => { ): void => {
const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet; const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet;
const validControlNets = filter( const validControlNets = getValidControlNets(controlNets);
controlNets,
(c) =>
c.isEnabled &&
(Boolean(c.processedControlImage) ||
(c.processorType === 'none' && Boolean(c.controlImage)))
);
if (isControlNetEnabled && Boolean(validControlNets.length)) { if (isControlNetEnabled && Boolean(validControlNets.length)) {
if (validControlNets.length > 1) { if (validControlNets.length > 1) {

View File

@ -18,6 +18,7 @@ import {
IntegerInputFieldTemplate, IntegerInputFieldTemplate,
ItemInputFieldTemplate, ItemInputFieldTemplate,
LatentsInputFieldTemplate, LatentsInputFieldTemplate,
LoRAModelInputFieldTemplate,
ModelInputFieldTemplate, ModelInputFieldTemplate,
OutputFieldTemplate, OutputFieldTemplate,
StringInputFieldTemplate, StringInputFieldTemplate,
@ -191,6 +192,21 @@ const buildVaeModelInputFieldTemplate = ({
return template; return template;
}; };
const buildLoRAModelInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): LoRAModelInputFieldTemplate => {
const template: LoRAModelInputFieldTemplate = {
...baseField,
type: 'lora_model',
inputRequirement: 'always',
inputKind: 'direct',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildImageInputFieldTemplate = ({ const buildImageInputFieldTemplate = ({
schemaObject, schemaObject,
baseField, baseField,
@ -460,6 +476,9 @@ export const buildInputFieldTemplate = (
if (['vae_model'].includes(fieldType)) { if (['vae_model'].includes(fieldType)) {
return buildVaeModelInputFieldTemplate({ schemaObject, baseField }); return buildVaeModelInputFieldTemplate({ schemaObject, baseField });
} }
if (['lora_model'].includes(fieldType)) {
return buildLoRAModelInputFieldTemplate({ schemaObject, baseField });
}
if (['enum'].includes(fieldType)) { if (['enum'].includes(fieldType)) {
return buildEnumInputFieldTemplate({ schemaObject, baseField }); return buildEnumInputFieldTemplate({ schemaObject, baseField });
} }

View File

@ -79,6 +79,10 @@ export const buildInputFieldValue = (
if (template.type === 'vae_model') { if (template.type === 'vae_model') {
fieldValue.value = undefined; fieldValue.value = undefined;
} }
if (template.type === 'lora_model') {
fieldValue.value = undefined;
}
} }
return fieldValue; return fieldValue;

View File

@ -0,0 +1,148 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { forEach, size } from 'lodash-es';
import { LoraLoaderInvocation } from 'services/api/types';
import { modelIdToLoRAModelField } from '../modelIdToLoRAName';
import {
LORA_LOADER,
MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING,
POSITIVE_CONDITIONING,
} from './constants';
export const addLoRAsToGraph = (
graph: NonNullableGraph,
state: RootState,
baseNodeId: string
): void => {
/**
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
* or to the inference/conditioning nodes.
*
* So we need to inject a LoRA chain into the graph.
*/
const { loras } = state.lora;
const loraCount = size(loras);
if (loraCount > 0) {
// remove any existing connections from main model loader, we need to insert the lora nodes
graph.edges = graph.edges.filter(
(e) =>
!(
e.source.node_id === MAIN_MODEL_LOADER &&
['unet', 'clip'].includes(e.source.field)
)
);
}
// we need to remember the last lora so we can chain from it
let lastLoraNodeId = '';
let currentLoraIndex = 0;
forEach(loras, (lora) => {
const { id, name, weight } = lora;
const loraField = modelIdToLoRAModelField(id);
const currentLoraNodeId = `${LORA_LOADER}_${loraField.model_name.replace(
'.',
'_'
)}`;
const loraLoaderNode: LoraLoaderInvocation = {
type: 'lora_loader',
id: currentLoraNodeId,
lora: loraField,
weight,
};
graph.nodes[currentLoraNodeId] = loraLoaderNode;
if (currentLoraIndex === 0) {
// first lora = start the lora chain, attach directly to model loader
graph.edges.push({
source: {
node_id: MAIN_MODEL_LOADER,
field: 'unet',
},
destination: {
node_id: currentLoraNodeId,
field: 'unet',
},
});
graph.edges.push({
source: {
node_id: MAIN_MODEL_LOADER,
field: 'clip',
},
destination: {
node_id: currentLoraNodeId,
field: 'clip',
},
});
} else {
// we are in the middle of the lora chain, instead connect to the previous lora
graph.edges.push({
source: {
node_id: lastLoraNodeId,
field: 'unet',
},
destination: {
node_id: currentLoraNodeId,
field: 'unet',
},
});
graph.edges.push({
source: {
node_id: lastLoraNodeId,
field: 'clip',
},
destination: {
node_id: currentLoraNodeId,
field: 'clip',
},
});
}
if (currentLoraIndex === loraCount - 1) {
// final lora, end the lora chain - we need to connect up to inference and conditioning nodes
graph.edges.push({
source: {
node_id: currentLoraNodeId,
field: 'unet',
},
destination: {
node_id: baseNodeId,
field: 'unet',
},
});
graph.edges.push({
source: {
node_id: currentLoraNodeId,
field: 'clip',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip',
},
});
graph.edges.push({
source: {
node_id: currentLoraNodeId,
field: 'clip',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip',
},
});
}
// increment the lora for the next one in the chain
lastLoraNodeId = currentLoraNodeId;
currentLoraIndex += 1;
});
};

View File

@ -9,6 +9,7 @@ import {
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField'; import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
import { import {
IMAGE_TO_IMAGE_GRAPH, IMAGE_TO_IMAGE_GRAPH,
@ -252,6 +253,8 @@ export const buildCanvasImageToImageGraph = (
}); });
} }
addLoRAsToGraph(graph, state, LATENTS_TO_LATENTS);
// Add VAE // Add VAE
addVAEToGraph(graph, state); addVAEToGraph(graph, state);

View File

@ -8,6 +8,7 @@ import {
RangeOfSizeInvocation, RangeOfSizeInvocation,
} from 'services/api/types'; } from 'services/api/types';
import { modelIdToMainModelField } from '../modelIdToMainModelField'; import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
import { import {
INPAINT, INPAINT,
@ -194,6 +195,8 @@ export const buildCanvasInpaintGraph = (
], ],
}; };
addLoRAsToGraph(graph, state, INPAINT);
// Add VAE // Add VAE
addVAEToGraph(graph, state); addVAEToGraph(graph, state);

View File

@ -3,6 +3,7 @@ import { NonNullableGraph } from 'features/nodes/types/types';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField'; import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
import { import {
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
@ -157,6 +158,8 @@ export const buildCanvasTextToImageGraph = (
], ],
}; };
addLoRAsToGraph(graph, state, TEXT_TO_LATENTS);
// Add VAE // Add VAE
addVAEToGraph(graph, state); addVAEToGraph(graph, state);

View File

@ -10,6 +10,7 @@ import {
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField'; import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
import { import {
IMAGE_COLLECTION, IMAGE_COLLECTION,
@ -304,6 +305,9 @@ export const buildLinearImageToImageGraph = (
}, },
}); });
} }
addLoRAsToGraph(graph, state, LATENTS_TO_LATENTS);
// Add VAE // Add VAE
addVAEToGraph(graph, state); addVAEToGraph(graph, state);

View File

@ -3,6 +3,7 @@ import { NonNullableGraph } from 'features/nodes/types/types';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField'; import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
import { import {
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
@ -150,6 +151,8 @@ export const buildLinearTextToImageGraph = (
], ],
}; };
addLoRAsToGraph(graph, state, TEXT_TO_LATENTS);
// Add Custom VAE Support // Add Custom VAE Support
addVAEToGraph(graph, state); addVAEToGraph(graph, state);

View File

@ -4,6 +4,7 @@ import { cloneDeep, omit, reduce } from 'lodash-es';
import { Graph } from 'services/api/types'; import { Graph } from 'services/api/types';
import { AnyInvocation } from 'services/events/types'; import { AnyInvocation } from 'services/events/types';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
import { modelIdToLoRAModelField } from '../modelIdToLoRAName';
import { modelIdToMainModelField } from '../modelIdToMainModelField'; import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { modelIdToVAEModelField } from '../modelIdToVAEModelField'; import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
@ -38,6 +39,12 @@ export const parseFieldValue = (field: InputFieldValue) => {
} }
} }
if (field.type === 'lora_model') {
if (field.value) {
return modelIdToLoRAModelField(field.value);
}
}
return field.value; return field.value;
}; };

View File

@ -9,6 +9,7 @@ export const RANGE_OF_SIZE = 'range_of_size';
export const ITERATE = 'iterate'; export const ITERATE = 'iterate';
export const MAIN_MODEL_LOADER = 'main_model_loader'; export const MAIN_MODEL_LOADER = 'main_model_loader';
export const VAE_LOADER = 'vae_loader'; export const VAE_LOADER = 'vae_loader';
export const LORA_LOADER = 'lora_loader';
export const IMAGE_TO_LATENTS = 'image_to_latents'; export const IMAGE_TO_LATENTS = 'image_to_latents';
export const LATENTS_TO_LATENTS = 'latents_to_latents'; export const LATENTS_TO_LATENTS = 'latents_to_latents';
export const RESIZE = 'resize_image'; export const RESIZE = 'resize_image';

View File

@ -0,0 +1,12 @@
import { BaseModelType, LoRAModelField } from 'services/api/types';
export const modelIdToLoRAModelField = (loraId: string): LoRAModelField => {
const [base_model, model_type, model_name] = loraId.split('/');
const field: LoRAModelField = {
base_model: base_model as BaseModelType,
model_name,
};
return field;
};

View File

@ -1,20 +1,15 @@
import { Flex, useDisclosure } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import { useTranslation } from 'react-i18next';
import IAICollapse from 'common/components/IAICollapse'; import IAICollapse from 'common/components/IAICollapse';
import { memo } from 'react'; import { memo } from 'react';
import ParamBoundingBoxWidth from './ParamBoundingBoxWidth'; import { useTranslation } from 'react-i18next';
import ParamBoundingBoxHeight from './ParamBoundingBoxHeight'; import ParamBoundingBoxHeight from './ParamBoundingBoxHeight';
import ParamBoundingBoxWidth from './ParamBoundingBoxWidth';
const ParamBoundingBoxCollapse = () => { const ParamBoundingBoxCollapse = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const { isOpen, onToggle } = useDisclosure();
return ( return (
<IAICollapse <IAICollapse label={t('parameters.boundingBoxHeader')}>
label={t('parameters.boundingBoxHeader')}
isOpen={isOpen}
onToggle={onToggle}
>
<Flex sx={{ gap: 2, flexDirection: 'column' }}> <Flex sx={{ gap: 2, flexDirection: 'column' }}>
<ParamBoundingBoxWidth /> <ParamBoundingBoxWidth />
<ParamBoundingBoxHeight /> <ParamBoundingBoxHeight />

View File

@ -1,4 +1,4 @@
import { Flex, useDisclosure } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import { memo } from 'react'; import { memo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -6,19 +6,14 @@ import IAICollapse from 'common/components/IAICollapse';
import ParamInfillMethod from './ParamInfillMethod'; import ParamInfillMethod from './ParamInfillMethod';
import ParamInfillTilesize from './ParamInfillTilesize'; import ParamInfillTilesize from './ParamInfillTilesize';
import ParamScaleBeforeProcessing from './ParamScaleBeforeProcessing'; import ParamScaleBeforeProcessing from './ParamScaleBeforeProcessing';
import ParamScaledWidth from './ParamScaledWidth';
import ParamScaledHeight from './ParamScaledHeight'; import ParamScaledHeight from './ParamScaledHeight';
import ParamScaledWidth from './ParamScaledWidth';
const ParamInfillCollapse = () => { const ParamInfillCollapse = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const { isOpen, onToggle } = useDisclosure();
return ( return (
<IAICollapse <IAICollapse label={t('parameters.infillScalingHeader')}>
label={t('parameters.infillScalingHeader')}
isOpen={isOpen}
onToggle={onToggle}
>
<Flex sx={{ gap: 2, flexDirection: 'column' }}> <Flex sx={{ gap: 2, flexDirection: 'column' }}>
<ParamInfillMethod /> <ParamInfillMethod />
<ParamInfillTilesize /> <ParamInfillTilesize />

View File

@ -1,22 +1,16 @@
import IAICollapse from 'common/components/IAICollapse';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import ParamSeamBlur from './ParamSeamBlur'; import ParamSeamBlur from './ParamSeamBlur';
import ParamSeamSize from './ParamSeamSize'; import ParamSeamSize from './ParamSeamSize';
import ParamSeamSteps from './ParamSeamSteps'; import ParamSeamSteps from './ParamSeamSteps';
import ParamSeamStrength from './ParamSeamStrength'; import ParamSeamStrength from './ParamSeamStrength';
import { useDisclosure } from '@chakra-ui/react';
import { useTranslation } from 'react-i18next';
import IAICollapse from 'common/components/IAICollapse';
import { memo } from 'react';
const ParamSeamCorrectionCollapse = () => { const ParamSeamCorrectionCollapse = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const { isOpen, onToggle } = useDisclosure();
return ( return (
<IAICollapse <IAICollapse label={t('parameters.seamCorrectionHeader')}>
label={t('parameters.seamCorrectionHeader')}
isOpen={isOpen}
onToggle={onToggle}
>
<ParamSeamSize /> <ParamSeamSize />
<ParamSeamBlur /> <ParamSeamBlur />
<ParamSeamStrength /> <ParamSeamStrength />

View File

@ -1,41 +1,45 @@
import { Divider, Flex } from '@chakra-ui/react'; import { Divider, Flex } from '@chakra-ui/react';
import { useTranslation } from 'react-i18next';
import IAICollapse from 'common/components/IAICollapse';
import { Fragment, memo, useCallback } from 'react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton';
import IAICollapse from 'common/components/IAICollapse';
import ControlNet from 'features/controlNet/components/ControlNet';
import ParamControlNetFeatureToggle from 'features/controlNet/components/parameters/ParamControlNetFeatureToggle';
import { import {
controlNetAdded, controlNetAdded,
controlNetSelector, controlNetSelector,
isControlNetEnabledToggled,
} from 'features/controlNet/store/controlNetSlice'; } from 'features/controlNet/store/controlNetSlice';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
import { map } from 'lodash-es';
import { v4 as uuidv4 } from 'uuid';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import IAIButton from 'common/components/IAIButton'; import { map } from 'lodash-es';
import ControlNet from 'features/controlNet/components/ControlNet'; import { Fragment, memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { v4 as uuidv4 } from 'uuid';
const selector = createSelector( const selector = createSelector(
controlNetSelector, controlNetSelector,
(controlNet) => { (controlNet) => {
const { controlNets, isEnabled } = controlNet; const { controlNets, isEnabled } = controlNet;
return { controlNetsArray: map(controlNets), isEnabled }; const validControlNets = getValidControlNets(controlNets);
const activeLabel =
isEnabled && validControlNets.length > 0
? `${validControlNets.length} Active`
: undefined;
return { controlNetsArray: map(controlNets), activeLabel };
}, },
defaultSelectorOptions defaultSelectorOptions
); );
const ParamControlNetCollapse = () => { const ParamControlNetCollapse = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const { controlNetsArray, isEnabled } = useAppSelector(selector); const { controlNetsArray, activeLabel } = useAppSelector(selector);
const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled; const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const handleClickControlNetToggle = useCallback(() => {
dispatch(isControlNetEnabledToggled());
}, [dispatch]);
const handleClickedAddControlNet = useCallback(() => { const handleClickedAddControlNet = useCallback(() => {
dispatch(controlNetAdded({ controlNetId: uuidv4() })); dispatch(controlNetAdded({ controlNetId: uuidv4() }));
}, [dispatch]); }, [dispatch]);
@ -45,13 +49,9 @@ const ParamControlNetCollapse = () => {
} }
return ( return (
<IAICollapse <IAICollapse label="ControlNet" activeLabel={activeLabel}>
label={'ControlNet'}
isOpen={isEnabled}
onToggle={handleClickControlNetToggle}
withSwitch
>
<Flex sx={{ flexDir: 'column', gap: 3 }}> <Flex sx={{ flexDir: 'column', gap: 3 }}>
<ParamControlNetFeatureToggle />
{controlNetsArray.map((c, i) => ( {controlNetsArray.map((c, i) => (
<Fragment key={c.controlNetId}> <Fragment key={c.controlNetId}>
{i > 0 && <Divider />} {i > 0 && <Divider />}

View File

@ -1,5 +1,6 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAINumberInput from 'common/components/IAINumberInput'; import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
@ -27,7 +28,8 @@ const selector = createSelector(
shouldUseSliders, shouldUseSliders,
shift, shift,
}; };
} },
defaultSelectorOptions
); );
const ParamCFGScale = () => { const ParamCFGScale = () => {

View File

@ -1,5 +1,6 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider, { IAIFullSliderProps } from 'common/components/IAISlider'; import IAISlider, { IAIFullSliderProps } from 'common/components/IAISlider';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import { setHeight } from 'features/parameters/store/generationSlice'; import { setHeight } from 'features/parameters/store/generationSlice';
@ -25,7 +26,8 @@ const selector = createSelector(
inputMax, inputMax,
step, step,
}; };
} },
defaultSelectorOptions
); );
type ParamHeightProps = Omit< type ParamHeightProps = Omit<

View File

@ -1,37 +1,38 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAINumberInput from 'common/components/IAINumberInput'; import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { setIterations } from 'features/parameters/store/generationSlice'; import { setIterations } from 'features/parameters/store/generationSlice';
import { configSelector } from 'features/system/store/configSelectors';
import { hotkeysSelector } from 'features/ui/store/hotkeysSlice';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
const selector = createSelector([stateSelector], (state) => { const selector = createSelector(
const { initial, min, sliderMax, inputMax, fineStep, coarseStep } = [stateSelector],
state.config.sd.iterations; (state) => {
const { iterations } = state.generation; const { initial, min, sliderMax, inputMax, fineStep, coarseStep } =
const { shouldUseSliders } = state.ui; state.config.sd.iterations;
const isDisabled = const { iterations } = state.generation;
state.dynamicPrompts.isEnabled && state.dynamicPrompts.combinatorial; const { shouldUseSliders } = state.ui;
const isDisabled =
state.dynamicPrompts.isEnabled && state.dynamicPrompts.combinatorial;
const step = state.hotkeys.shift ? fineStep : coarseStep; const step = state.hotkeys.shift ? fineStep : coarseStep;
return { return {
iterations, iterations,
initial, initial,
min, min,
sliderMax, sliderMax,
inputMax, inputMax,
step, step,
shouldUseSliders, shouldUseSliders,
isDisabled, isDisabled,
}; };
}); },
defaultSelectorOptions
);
const ParamIterations = () => { const ParamIterations = () => {
const { const {

View File

@ -1,5 +1,6 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAINumberInput from 'common/components/IAINumberInput'; import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
@ -33,7 +34,8 @@ const selector = createSelector(
step, step,
shouldUseSliders, shouldUseSliders,
}; };
} },
defaultSelectorOptions
); );
const ParamSteps = () => { const ParamSteps = () => {

View File

@ -1,7 +1,7 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { IAIFullSliderProps } from 'common/components/IAISlider'; import IAISlider, { IAIFullSliderProps } from 'common/components/IAISlider';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import { setWidth } from 'features/parameters/store/generationSlice'; import { setWidth } from 'features/parameters/store/generationSlice';
import { configSelector } from 'features/system/store/configSelectors'; import { configSelector } from 'features/system/store/configSelectors';
@ -26,7 +26,8 @@ const selector = createSelector(
inputMax, inputMax,
step, step,
}; };
} },
defaultSelectorOptions
); );
type ParamWidthProps = Omit<IAIFullSliderProps, 'label' | 'value' | 'onChange'>; type ParamWidthProps = Omit<IAIFullSliderProps, 'label' | 'value' | 'onChange'>;

View File

@ -1,37 +1,39 @@
import { Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import { useTranslation } from 'react-i18next'; import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { stateSelector } from 'app/store/store';
import { RootState } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse'; import IAICollapse from 'common/components/IAICollapse';
import { memo } from 'react';
import { ParamHiresStrength } from './ParamHiresStrength';
import { setHiresFix } from 'features/parameters/store/postprocessingSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { ParamHiresStrength } from './ParamHiresStrength';
import { ParamHiresToggle } from './ParamHiresToggle';
const selector = createSelector(
stateSelector,
(state) => {
const activeLabel = state.postprocessing.hiresFix ? 'Enabled' : undefined;
return { activeLabel };
},
defaultSelectorOptions
);
const ParamHiresCollapse = () => { const ParamHiresCollapse = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const hiresFix = useAppSelector( const { activeLabel } = useAppSelector(selector);
(state: RootState) => state.postprocessing.hiresFix
);
const isHiresEnabled = useFeatureStatus('hires').isFeatureEnabled; const isHiresEnabled = useFeatureStatus('hires').isFeatureEnabled;
const dispatch = useAppDispatch();
const handleToggle = () => dispatch(setHiresFix(!hiresFix));
if (!isHiresEnabled) { if (!isHiresEnabled) {
return null; return null;
} }
return ( return (
<IAICollapse <IAICollapse label={t('parameters.hiresOptim')} activeLabel={activeLabel}>
label={t('parameters.hiresOptim')}
isOpen={hiresFix}
onToggle={handleToggle}
withSwitch
>
<Flex sx={{ gap: 2, flexDirection: 'column' }}> <Flex sx={{ gap: 2, flexDirection: 'column' }}>
<ParamHiresToggle />
<ParamHiresStrength /> <ParamHiresStrength />
</Flex> </Flex>
</IAICollapse> </IAICollapse>

View File

@ -23,7 +23,6 @@ export const ParamHiresToggle = () => {
return ( return (
<IAISwitch <IAISwitch
label={t('parameters.hiresOptim')} label={t('parameters.hiresOptim')}
fontSize="md"
isChecked={hiresFix} isChecked={hiresFix}
onChange={handleChangeHiresFix} onChange={handleChangeHiresFix}
/> />

View File

@ -1,27 +1,33 @@
import { useTranslation } from 'react-i18next';
import { Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse'; import IAICollapse from 'common/components/IAICollapse';
import ParamPerlinNoise from './ParamPerlinNoise';
import ParamNoiseThreshold from './ParamNoiseThreshold';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setShouldUseNoiseSettings } from 'features/parameters/store/generationSlice';
import { memo } from 'react';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import ParamNoiseThreshold from './ParamNoiseThreshold';
import { ParamNoiseToggle } from './ParamNoiseToggle';
import ParamPerlinNoise from './ParamPerlinNoise';
const selector = createSelector(
stateSelector,
(state) => {
const { shouldUseNoiseSettings } = state.generation;
return {
activeLabel: shouldUseNoiseSettings ? 'Enabled' : undefined,
};
},
defaultSelectorOptions
);
const ParamNoiseCollapse = () => { const ParamNoiseCollapse = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const isNoiseEnabled = useFeatureStatus('noise').isFeatureEnabled; const isNoiseEnabled = useFeatureStatus('noise').isFeatureEnabled;
const shouldUseNoiseSettings = useAppSelector( const { activeLabel } = useAppSelector(selector);
(state: RootState) => state.generation.shouldUseNoiseSettings
);
const dispatch = useAppDispatch();
const handleToggle = () =>
dispatch(setShouldUseNoiseSettings(!shouldUseNoiseSettings));
if (!isNoiseEnabled) { if (!isNoiseEnabled) {
return null; return null;
@ -30,11 +36,10 @@ const ParamNoiseCollapse = () => {
return ( return (
<IAICollapse <IAICollapse
label={t('parameters.noiseSettings')} label={t('parameters.noiseSettings')}
isOpen={shouldUseNoiseSettings} activeLabel={activeLabel}
onToggle={handleToggle}
withSwitch
> >
<Flex sx={{ gap: 2, flexDirection: 'column' }}> <Flex sx={{ gap: 2, flexDirection: 'column' }}>
<ParamNoiseToggle />
<ParamPerlinNoise /> <ParamPerlinNoise />
<ParamNoiseThreshold /> <ParamNoiseThreshold />
</Flex> </Flex>

View File

@ -1,18 +1,31 @@
import { RootState } from 'app/store/store'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { setThreshold } from 'features/parameters/store/generationSlice'; import { setThreshold } from 'features/parameters/store/generationSlice';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
const selector = createSelector(
stateSelector,
(state) => {
const { shouldUseNoiseSettings, threshold } = state.generation;
return {
isDisabled: !shouldUseNoiseSettings,
threshold,
};
},
defaultSelectorOptions
);
export default function ParamNoiseThreshold() { export default function ParamNoiseThreshold() {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const threshold = useAppSelector( const { threshold, isDisabled } = useAppSelector(selector);
(state: RootState) => state.generation.threshold
);
const { t } = useTranslation(); const { t } = useTranslation();
return ( return (
<IAISlider <IAISlider
isDisabled={isDisabled}
label={t('parameters.noiseThreshold')} label={t('parameters.noiseThreshold')}
min={0} min={0}
max={20} max={20}

View File

@ -0,0 +1,27 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { setShouldUseNoiseSettings } from 'features/parameters/store/generationSlice';
import { ChangeEvent } from 'react';
import { useTranslation } from 'react-i18next';
export const ParamNoiseToggle = () => {
const dispatch = useAppDispatch();
const shouldUseNoiseSettings = useAppSelector(
(state: RootState) => state.generation.shouldUseNoiseSettings
);
const { t } = useTranslation();
const handleChange = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldUseNoiseSettings(e.target.checked));
return (
<IAISwitch
label="Enable Noise Settings"
isChecked={shouldUseNoiseSettings}
onChange={handleChange}
/>
);
};

View File

@ -1,16 +1,31 @@
import { RootState } from 'app/store/store'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { setPerlin } from 'features/parameters/store/generationSlice'; import { setPerlin } from 'features/parameters/store/generationSlice';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
const selector = createSelector(
stateSelector,
(state) => {
const { shouldUseNoiseSettings, perlin } = state.generation;
return {
isDisabled: !shouldUseNoiseSettings,
perlin,
};
},
defaultSelectorOptions
);
export default function ParamPerlinNoise() { export default function ParamPerlinNoise() {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const perlin = useAppSelector((state: RootState) => state.generation.perlin); const { perlin, isDisabled } = useAppSelector(selector);
const { t } = useTranslation(); const { t } = useTranslation();
return ( return (
<IAISlider <IAISlider
isDisabled={isDisabled}
label={t('parameters.perlinNoise')} label={t('parameters.perlinNoise')}
min={0} min={0}
max={1} max={1}

View File

@ -1,36 +1,46 @@
import { useTranslation } from 'react-i18next';
import { Box, Flex } from '@chakra-ui/react'; import { Box, Flex } from '@chakra-ui/react';
import IAICollapse from 'common/components/IAICollapse';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setSeamless } from 'features/parameters/store/generationSlice';
import { memo } from 'react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import ParamSeamlessXAxis from './ParamSeamlessXAxis'; import ParamSeamlessXAxis from './ParamSeamlessXAxis';
import ParamSeamlessYAxis from './ParamSeamlessYAxis'; import ParamSeamlessYAxis from './ParamSeamlessYAxis';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
const getActiveLabel = (seamlessXAxis: boolean, seamlessYAxis: boolean) => {
if (seamlessXAxis && seamlessYAxis) {
return 'X & Y';
}
if (seamlessXAxis) {
return 'X';
}
if (seamlessYAxis) {
return 'Y';
}
};
const selector = createSelector( const selector = createSelector(
generationSelector, generationSelector,
(generation) => { (generation) => {
const { shouldUseSeamless, seamlessXAxis, seamlessYAxis } = generation; const { seamlessXAxis, seamlessYAxis } = generation;
return { shouldUseSeamless, seamlessXAxis, seamlessYAxis }; const activeLabel = getActiveLabel(seamlessXAxis, seamlessYAxis);
return { activeLabel };
}, },
defaultSelectorOptions defaultSelectorOptions
); );
const ParamSeamlessCollapse = () => { const ParamSeamlessCollapse = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const { shouldUseSeamless } = useAppSelector(selector); const { activeLabel } = useAppSelector(selector);
const isSeamlessEnabled = useFeatureStatus('seamless').isFeatureEnabled; const isSeamlessEnabled = useFeatureStatus('seamless').isFeatureEnabled;
const dispatch = useAppDispatch();
const handleToggle = () => dispatch(setSeamless(!shouldUseSeamless));
if (!isSeamlessEnabled) { if (!isSeamlessEnabled) {
return null; return null;
} }
@ -38,9 +48,7 @@ const ParamSeamlessCollapse = () => {
return ( return (
<IAICollapse <IAICollapse
label={t('parameters.seamlessTiling')} label={t('parameters.seamlessTiling')}
isOpen={shouldUseSeamless} activeLabel={activeLabel}
onToggle={handleToggle}
withSwitch
> >
<Flex sx={{ gap: 5 }}> <Flex sx={{ gap: 5 }}>
<Box flexGrow={1}> <Box flexGrow={1}>

View File

@ -1,39 +1,39 @@
import { memo } from 'react';
import { Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import { memo } from 'react';
import ParamSymmetryHorizontal from './ParamSymmetryHorizontal'; import ParamSymmetryHorizontal from './ParamSymmetryHorizontal';
import ParamSymmetryVertical from './ParamSymmetryVertical'; import ParamSymmetryVertical from './ParamSymmetryVertical';
import { useTranslation } from 'react-i18next'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse'; import IAICollapse from 'common/components/IAICollapse';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setShouldUseSymmetry } from 'features/parameters/store/generationSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useTranslation } from 'react-i18next';
import ParamSymmetryToggle from './ParamSymmetryToggle';
const selector = createSelector(
stateSelector,
(state) => ({
activeLabel: state.generation.shouldUseSymmetry ? 'Enabled' : undefined,
}),
defaultSelectorOptions
);
const ParamSymmetryCollapse = () => { const ParamSymmetryCollapse = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const shouldUseSymmetry = useAppSelector( const { activeLabel } = useAppSelector(selector);
(state: RootState) => state.generation.shouldUseSymmetry
);
const isSymmetryEnabled = useFeatureStatus('symmetry').isFeatureEnabled; const isSymmetryEnabled = useFeatureStatus('symmetry').isFeatureEnabled;
const dispatch = useAppDispatch();
const handleToggle = () => dispatch(setShouldUseSymmetry(!shouldUseSymmetry));
if (!isSymmetryEnabled) { if (!isSymmetryEnabled) {
return null; return null;
} }
return ( return (
<IAICollapse <IAICollapse label={t('parameters.symmetry')} activeLabel={activeLabel}>
label={t('parameters.symmetry')}
isOpen={shouldUseSymmetry}
onToggle={handleToggle}
withSwitch
>
<Flex sx={{ gap: 2, flexDirection: 'column' }}> <Flex sx={{ gap: 2, flexDirection: 'column' }}>
<ParamSymmetryToggle />
<ParamSymmetryHorizontal /> <ParamSymmetryHorizontal />
<ParamSymmetryVertical /> <ParamSymmetryVertical />
</Flex> </Flex>

View File

@ -12,6 +12,7 @@ export default function ParamSymmetryToggle() {
return ( return (
<IAISwitch <IAISwitch
label="Enable Symmetry"
isChecked={shouldUseSymmetry} isChecked={shouldUseSymmetry}
onChange={(e) => dispatch(setShouldUseSymmetry(e.target.checked))} onChange={(e) => dispatch(setShouldUseSymmetry(e.target.checked))}
/> />

View File

@ -1,39 +1,42 @@
import ParamVariationWeights from './ParamVariationWeights';
import ParamVariationAmount from './ParamVariationAmount';
import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { RootState } from 'app/store/store';
import { setShouldGenerateVariations } from 'features/parameters/store/generationSlice';
import { Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse'; import IAICollapse from 'common/components/IAICollapse';
import { memo } from 'react';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import ParamVariationAmount from './ParamVariationAmount';
import { ParamVariationToggle } from './ParamVariationToggle';
import ParamVariationWeights from './ParamVariationWeights';
const selector = createSelector(
stateSelector,
(state) => {
const activeLabel = state.generation.shouldGenerateVariations
? 'Enabled'
: undefined;
return { activeLabel };
},
defaultSelectorOptions
);
const ParamVariationCollapse = () => { const ParamVariationCollapse = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const shouldGenerateVariations = useAppSelector( const { activeLabel } = useAppSelector(selector);
(state: RootState) => state.generation.shouldGenerateVariations
);
const isVariationEnabled = useFeatureStatus('variation').isFeatureEnabled; const isVariationEnabled = useFeatureStatus('variation').isFeatureEnabled;
const dispatch = useAppDispatch();
const handleToggle = () =>
dispatch(setShouldGenerateVariations(!shouldGenerateVariations));
if (!isVariationEnabled) { if (!isVariationEnabled) {
return null; return null;
} }
return ( return (
<IAICollapse <IAICollapse label={t('parameters.variations')} activeLabel={activeLabel}>
label={t('parameters.variations')}
isOpen={shouldGenerateVariations}
onToggle={handleToggle}
withSwitch
>
<Flex sx={{ gap: 2, flexDirection: 'column' }}> <Flex sx={{ gap: 2, flexDirection: 'column' }}>
<ParamVariationToggle />
<ParamVariationAmount /> <ParamVariationAmount />
<ParamVariationWeights /> <ParamVariationWeights />
</Flex> </Flex>

View File

@ -0,0 +1,27 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { setShouldGenerateVariations } from 'features/parameters/store/generationSlice';
import { ChangeEvent } from 'react';
import { useTranslation } from 'react-i18next';
export const ParamVariationToggle = () => {
const dispatch = useAppDispatch();
const shouldGenerateVariations = useAppSelector(
(state: RootState) => state.generation.shouldGenerateVariations
);
const { t } = useTranslation();
const handleChange = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldGenerateVariations(e.target.checked));
return (
<IAISwitch
label="Enable Variations"
isChecked={shouldGenerateVariations}
onChange={handleChange}
/>
);
};

View File

@ -49,7 +49,6 @@ export interface GenerationState {
verticalSymmetrySteps: number; verticalSymmetrySteps: number;
model: ModelParam; model: ModelParam;
vae: VAEParam; vae: VAEParam;
shouldUseSeamless: boolean;
seamlessXAxis: boolean; seamlessXAxis: boolean;
seamlessYAxis: boolean; seamlessYAxis: boolean;
} }
@ -84,9 +83,8 @@ export const initialGenerationState: GenerationState = {
verticalSymmetrySteps: 0, verticalSymmetrySteps: 0,
model: '', model: '',
vae: '', vae: '',
shouldUseSeamless: false, seamlessXAxis: false,
seamlessXAxis: true, seamlessYAxis: false,
seamlessYAxis: true,
}; };
const initialState: GenerationState = initialGenerationState; const initialState: GenerationState = initialGenerationState;
@ -144,9 +142,6 @@ export const generationSlice = createSlice({
setImg2imgStrength: (state, action: PayloadAction<number>) => { setImg2imgStrength: (state, action: PayloadAction<number>) => {
state.img2imgStrength = action.payload; state.img2imgStrength = action.payload;
}, },
setSeamless: (state, action: PayloadAction<boolean>) => {
state.shouldUseSeamless = action.payload;
},
setSeamlessXAxis: (state, action: PayloadAction<boolean>) => { setSeamlessXAxis: (state, action: PayloadAction<boolean>) => {
state.seamlessXAxis = action.payload; state.seamlessXAxis = action.payload;
}, },
@ -268,7 +263,6 @@ export const {
modelSelected, modelSelected,
vaeSelected, vaeSelected,
setShouldUseNoiseSettings, setShouldUseNoiseSettings,
setSeamless,
setSeamlessXAxis, setSeamlessXAxis,
setSeamlessYAxis, setSeamlessYAxis,
} = generationSlice.actions; } = generationSlice.actions;

View File

@ -8,7 +8,7 @@ import { modelSelected } from 'features/parameters/store/generationSlice';
import { SelectItem } from '@mantine/core'; import { SelectItem } from '@mantine/core';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { forEach, isString } from 'lodash-es'; import { forEach, isString } from 'lodash-es';
import { useListModelsQuery } from 'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models';
export const MODEL_TYPE_MAP = { export const MODEL_TYPE_MAP = {
'sd-1': 'Stable Diffusion 1.x', 'sd-1': 'Stable Diffusion 1.x',
@ -23,9 +23,7 @@ const ModelSelect = () => {
(state: RootState) => state.generation.model (state: RootState) => state.generation.model
); );
const { data: mainModels, isLoading } = useListModelsQuery({ const { data: mainModels, isLoading } = useGetMainModelsQuery();
model_type: 'main',
});
const data = useMemo(() => { const data = useMemo(() => {
if (!mainModels) { if (!mainModels) {

View File

@ -6,7 +6,7 @@ import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { SelectItem } from '@mantine/core'; import { SelectItem } from '@mantine/core';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
import { useListModelsQuery } from 'services/api/endpoints/models'; import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { vaeSelected } from 'features/parameters/store/generationSlice'; import { vaeSelected } from 'features/parameters/store/generationSlice';
@ -16,9 +16,7 @@ const VAESelect = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const { data: vaeModels } = useListModelsQuery({ const { data: vaeModels } = useGetVaeModelsQuery();
model_type: 'vae',
});
const selectedModelId = useAppSelector( const selectedModelId = useAppSelector(
(state: RootState) => state.generation.vae (state: RootState) => state.generation.vae

View File

@ -66,16 +66,16 @@ const tabs: InvokeTabInfo[] = [
icon: <Icon as={MdDeviceHub} sx={{ boxSize: 6, pointerEvents: 'none' }} />, icon: <Icon as={MdDeviceHub} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
content: <NodesTab />, content: <NodesTab />,
}, },
// {
// id: 'batch',
// icon: <Icon as={FaLayerGroup} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
// content: <BatchTab />,
// },
{ {
id: 'modelManager', id: 'modelManager',
icon: <Icon as={FaCube} sx={{ boxSize: 6, pointerEvents: 'none' }} />, icon: <Icon as={FaCube} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
content: <ModelManagerTab />, content: <ModelManagerTab />,
}, },
// {
// id: 'batch',
// icon: <Icon as={FaLayerGroup} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
// content: <BatchTab />,
// },
]; ];
const enabledTabsSelector = createSelector( const enabledTabsSelector = createSelector(

View File

@ -1,4 +1,4 @@
import { Box, Flex, useDisclosure } from '@chakra-ui/react'; import { Box, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
@ -21,19 +21,25 @@ const selector = createSelector(
[uiSelector, generationSelector], [uiSelector, generationSelector],
(ui, generation) => { (ui, generation) => {
const { shouldUseSliders } = ui; const { shouldUseSliders } = ui;
const { shouldFitToWidthHeight } = generation; const { shouldFitToWidthHeight, shouldRandomizeSeed } = generation;
return { shouldUseSliders, shouldFitToWidthHeight }; const activeLabel = !shouldRandomizeSeed ? 'Manual Seed' : undefined;
return { shouldUseSliders, shouldFitToWidthHeight, activeLabel };
}, },
defaultSelectorOptions defaultSelectorOptions
); );
const ImageToImageTabCoreParameters = () => { const ImageToImageTabCoreParameters = () => {
const { shouldUseSliders, shouldFitToWidthHeight } = useAppSelector(selector); const { shouldUseSliders, shouldFitToWidthHeight, activeLabel } =
const { isOpen, onToggle } = useDisclosure({ defaultIsOpen: true }); useAppSelector(selector);
return ( return (
<IAICollapse label={'General'} isOpen={isOpen} onToggle={onToggle}> <IAICollapse
label={'General'}
activeLabel={activeLabel}
defaultIsOpen={true}
>
<Flex <Flex
sx={{ sx={{
flexDirection: 'column', flexDirection: 'column',

View File

@ -1,14 +1,15 @@
import { memo } from 'react';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import ImageToImageTabCoreParameters from './ImageToImageTabCoreParameters';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse'; import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import { memo } from 'react';
import ImageToImageTabCoreParameters from './ImageToImageTabCoreParameters';
const ImageToImageTabParameters = () => { const ImageToImageTabParameters = () => {
return ( return (
@ -17,6 +18,7 @@ const ImageToImageTabParameters = () => {
<ParamNegativeConditioning /> <ParamNegativeConditioning />
<ProcessButtons /> <ProcessButtons />
<ImageToImageTabCoreParameters /> <ImageToImageTabCoreParameters />
<ParamLoraCollapse />
<ParamDynamicPromptsCollapse /> <ParamDynamicPromptsCollapse />
<ParamControlNetCollapse /> <ParamControlNetCollapse />
<ParamVariationCollapse /> <ParamVariationCollapse />

View File

@ -9,16 +9,14 @@ import IAISlider from 'common/components/IAISlider';
import { pickBy } from 'lodash-es'; import { pickBy } from 'lodash-es';
import { useState } from 'react'; import { useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useListModelsQuery } from 'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models';
export default function MergeModelsPanel() { export default function MergeModelsPanel() {
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data } = useListModelsQuery({ const { data } = useGetMainModelsQuery();
model_type: 'main',
});
const diffusersModels = pickBy( const diffusersModels = pickBy(
data?.entities, data?.entities,

View File

@ -2,15 +2,13 @@ import { Flex } from '@chakra-ui/react';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { useListModelsQuery } from 'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit'; import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit'; import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
import ModelList from './ModelManagerPanel/ModelList'; import ModelList from './ModelManagerPanel/ModelList';
export default function ModelManagerPanel() { export default function ModelManagerPanel() {
const { data: mainModels } = useListModelsQuery({ const { data: mainModels } = useGetMainModelsQuery();
model_type: 'main',
});
const openModel = useAppSelector( const openModel = useAppSelector(
(state: RootState) => state.system.openModel (state: RootState) => state.system.openModel

View File

@ -8,7 +8,7 @@ import { useTranslation } from 'react-i18next';
import type { ChangeEvent, ReactNode } from 'react'; import type { ChangeEvent, ReactNode } from 'react';
import React, { useMemo, useState, useTransition } from 'react'; import React, { useMemo, useState, useTransition } from 'react';
import { useListModelsQuery } from 'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models';
function ModelFilterButton({ function ModelFilterButton({
label, label,
@ -36,9 +36,7 @@ function ModelFilterButton({
} }
const ModelList = () => { const ModelList = () => {
const { data: mainModels } = useListModelsQuery({ const { data: mainModels } = useGetMainModelsQuery();
model_type: 'main',
});
const [renderModelList, setRenderModelList] = React.useState<boolean>(false); const [renderModelList, setRenderModelList] = React.useState<boolean>(false);

View File

@ -1,5 +1,6 @@
import { Box, Flex, useDisclosure } from '@chakra-ui/react'; import { Box, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse'; import IAICollapse from 'common/components/IAICollapse';
@ -11,25 +12,30 @@ import ParamScheduler from 'features/parameters/components/Parameters/Core/Param
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps'; import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth'; import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth';
import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull'; import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo } from 'react'; import { memo } from 'react';
const selector = createSelector( const selector = createSelector(
uiSelector, stateSelector,
(ui) => { ({ ui, generation }) => {
const { shouldUseSliders } = ui; const { shouldUseSliders } = ui;
const { shouldRandomizeSeed } = generation;
return { shouldUseSliders }; const activeLabel = !shouldRandomizeSeed ? 'Manual Seed' : undefined;
return { shouldUseSliders, activeLabel };
}, },
defaultSelectorOptions defaultSelectorOptions
); );
const TextToImageTabCoreParameters = () => { const TextToImageTabCoreParameters = () => {
const { shouldUseSliders } = useAppSelector(selector); const { shouldUseSliders, activeLabel } = useAppSelector(selector);
const { isOpen, onToggle } = useDisclosure({ defaultIsOpen: true });
return ( return (
<IAICollapse label={'General'} isOpen={isOpen} onToggle={onToggle}> <IAICollapse
label={'General'}
activeLabel={activeLabel}
defaultIsOpen={true}
>
<Flex <Flex
sx={{ sx={{
flexDirection: 'column', flexDirection: 'column',

View File

@ -1,15 +1,16 @@
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamHiresCollapse from 'features/parameters/components/Parameters/Hires/ParamHiresCollapse';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import { memo } from 'react'; import { memo } from 'react';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
import ParamHiresCollapse from 'features/parameters/components/Parameters/Hires/ParamHiresCollapse';
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import TextToImageTabCoreParameters from './TextToImageTabCoreParameters'; import TextToImageTabCoreParameters from './TextToImageTabCoreParameters';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
const TextToImageTabParameters = () => { const TextToImageTabParameters = () => {
return ( return (
@ -18,6 +19,7 @@ const TextToImageTabParameters = () => {
<ParamNegativeConditioning /> <ParamNegativeConditioning />
<ProcessButtons /> <ProcessButtons />
<TextToImageTabCoreParameters /> <TextToImageTabCoreParameters />
<ParamLoraCollapse />
<ParamDynamicPromptsCollapse /> <ParamDynamicPromptsCollapse />
<ParamControlNetCollapse /> <ParamControlNetCollapse />
<ParamVariationCollapse /> <ParamVariationCollapse />

View File

@ -1,5 +1,6 @@
import { Box, Flex, useDisclosure } from '@chakra-ui/react'; import { Box, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse'; import IAICollapse from 'common/components/IAICollapse';
@ -12,25 +13,30 @@ import ParamScheduler from 'features/parameters/components/Parameters/Core/Param
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps'; import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength'; import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength';
import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull'; import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo } from 'react'; import { memo } from 'react';
const selector = createSelector( const selector = createSelector(
uiSelector, stateSelector,
(ui) => { ({ ui, generation }) => {
const { shouldUseSliders } = ui; const { shouldUseSliders } = ui;
const { shouldRandomizeSeed } = generation;
return { shouldUseSliders }; const activeLabel = !shouldRandomizeSeed ? 'Manual Seed' : undefined;
return { shouldUseSliders, activeLabel };
}, },
defaultSelectorOptions defaultSelectorOptions
); );
const UnifiedCanvasCoreParameters = () => { const UnifiedCanvasCoreParameters = () => {
const { shouldUseSliders } = useAppSelector(selector); const { shouldUseSliders, activeLabel } = useAppSelector(selector);
const { isOpen, onToggle } = useDisclosure({ defaultIsOpen: true });
return ( return (
<IAICollapse label={'General'} isOpen={isOpen} onToggle={onToggle}> <IAICollapse
label={'General'}
activeLabel={activeLabel}
defaultIsOpen={true}
>
<Flex <Flex
sx={{ sx={{
flexDirection: 'column', flexDirection: 'column',

View File

@ -1,14 +1,15 @@
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
import ParamInfillAndScalingCollapse from 'features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse'; import ParamInfillAndScalingCollapse from 'features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse';
import ParamSeamCorrectionCollapse from 'features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse'; import ParamSeamCorrectionCollapse from 'features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse';
import UnifiedCanvasCoreParameters from './UnifiedCanvasCoreParameters';
import { memo } from 'react';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse'; import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import { memo } from 'react';
import UnifiedCanvasCoreParameters from './UnifiedCanvasCoreParameters';
const UnifiedCanvasParameters = () => { const UnifiedCanvasParameters = () => {
return ( return (
@ -17,6 +18,7 @@ const UnifiedCanvasParameters = () => {
<ParamNegativeConditioning /> <ParamNegativeConditioning />
<ProcessButtons /> <ProcessButtons />
<UnifiedCanvasCoreParameters /> <UnifiedCanvasCoreParameters />
<ParamLoraCollapse />
<ParamDynamicPromptsCollapse /> <ParamDynamicPromptsCollapse />
<ParamControlNetCollapse /> <ParamControlNetCollapse />
<ParamVariationCollapse /> <ParamVariationCollapse />

View File

@ -1,13 +1,10 @@
export const tabMap = [ export const tabMap = [
'txt2img', 'txt2img',
'img2img', 'img2img',
// 'generate',
'unifiedCanvas', 'unifiedCanvas',
'nodes', 'nodes',
'batch',
// 'postprocessing',
// 'training',
'modelManager', 'modelManager',
'batch',
] as const; ] as const;
export type InvokeTabName = (typeof tabMap)[number]; export type InvokeTabName = (typeof tabMap)[number];

View File

@ -1,37 +1,85 @@
import { ModelsList } from 'services/api/types';
import { EntityState, createEntityAdapter } from '@reduxjs/toolkit'; import { EntityState, createEntityAdapter } from '@reduxjs/toolkit';
import { keyBy } from 'lodash-es'; import { cloneDeep } from 'lodash-es';
import {
AnyModelConfig,
ControlNetModelConfig,
LoRAModelConfig,
MainModelConfig,
TextualInversionModelConfig,
VaeModelConfig,
} from 'services/api/types';
import { ApiFullTagDescription, LIST_TAG, api } from '..'; import { ApiFullTagDescription, LIST_TAG, api } from '..';
import { paths } from '../schema';
type ModelConfig = ModelsList['models'][number]; export type MainModelConfigEntity = MainModelConfig & { id: string };
type ListModelsArg = NonNullable< export type LoRAModelConfigEntity = LoRAModelConfig & { id: string };
paths['/api/v1/models/']['get']['parameters']['query']
>;
const modelsAdapter = createEntityAdapter<ModelConfig>({ export type ControlNetModelConfigEntity = ControlNetModelConfig & {
selectId: (model) => getModelId(model), id: string;
};
export type TextualInversionModelConfigEntity = TextualInversionModelConfig & {
id: string;
};
export type VaeModelConfigEntity = VaeModelConfig & { id: string };
type AnyModelConfigEntity =
| MainModelConfigEntity
| LoRAModelConfigEntity
| ControlNetModelConfigEntity
| TextualInversionModelConfigEntity
| VaeModelConfigEntity;
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
const controlNetModelsAdapter =
createEntityAdapter<ControlNetModelConfigEntity>({
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
const textualInversionModelsAdapter =
createEntityAdapter<TextualInversionModelConfigEntity>({
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
const vaeModelsAdapter = createEntityAdapter<VaeModelConfigEntity>({
sortComparer: (a, b) => a.name.localeCompare(b.name), sortComparer: (a, b) => a.name.localeCompare(b.name),
}); });
const getModelId = ({ base_model, type, name }: ModelConfig) => export const getModelId = ({ base_model, type, name }: AnyModelConfig) =>
`${base_model}/${type}/${name}`; `${base_model}/${type}/${name}`;
const createModelEntities = <T extends AnyModelConfigEntity>(
models: AnyModelConfig[]
): T[] => {
const entityArray: T[] = [];
models.forEach((model) => {
const entity = {
...cloneDeep(model),
id: getModelId(model),
} as T;
entityArray.push(entity);
});
return entityArray;
};
export const modelsApi = api.injectEndpoints({ export const modelsApi = api.injectEndpoints({
endpoints: (build) => ({ endpoints: (build) => ({
listModels: build.query<EntityState<ModelConfig>, ListModelsArg>({ getMainModels: build.query<EntityState<MainModelConfigEntity>, void>({
query: (arg) => ({ url: 'models/', params: arg }), query: () => ({ url: 'models/', params: { model_type: 'main' } }),
providesTags: (result, error, arg) => { providesTags: (result, error, arg) => {
// any list of boards const tags: ApiFullTagDescription[] = [
const tags: ApiFullTagDescription[] = [{ id: 'Model', type: LIST_TAG }]; { id: 'MainModel', type: LIST_TAG },
];
if (result) { if (result) {
// and individual tags for each board
tags.push( tags.push(
...result.ids.map((id) => ({ ...result.ids.map((id) => ({
type: 'Model' as const, type: 'MainModel' as const,
id, id,
})) }))
); );
@ -39,14 +87,161 @@ export const modelsApi = api.injectEndpoints({
return tags; return tags;
}, },
transformResponse: (response: ModelsList, meta, arg) => { transformResponse: (
return modelsAdapter.setAll( response: { models: MainModelConfig[] },
modelsAdapter.getInitialState(), meta,
keyBy(response.models, getModelId) arg
) => {
const entities = createModelEntities<MainModelConfigEntity>(
response.models
);
return mainModelsAdapter.setAll(
mainModelsAdapter.getInitialState(),
entities
);
},
}),
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ id: 'LoRAModel', type: LIST_TAG },
];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'LoRAModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (
response: { models: LoRAModelConfig[] },
meta,
arg
) => {
const entities = createModelEntities<LoRAModelConfigEntity>(
response.models
);
return loraModelsAdapter.setAll(
loraModelsAdapter.getInitialState(),
entities
);
},
}),
getControlNetModels: build.query<
EntityState<ControlNetModelConfigEntity>,
void
>({
query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ id: 'ControlNetModel', type: LIST_TAG },
];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'ControlNetModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (
response: { models: ControlNetModelConfig[] },
meta,
arg
) => {
const entities = createModelEntities<ControlNetModelConfigEntity>(
response.models
);
return controlNetModelsAdapter.setAll(
controlNetModelsAdapter.getInitialState(),
entities
);
},
}),
getVaeModels: build.query<EntityState<VaeModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ id: 'VaeModel', type: LIST_TAG },
];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'VaeModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (
response: { models: VaeModelConfig[] },
meta,
arg
) => {
const entities = createModelEntities<VaeModelConfigEntity>(
response.models
);
return vaeModelsAdapter.setAll(
vaeModelsAdapter.getInitialState(),
entities
);
},
}),
getTextualInversionModels: build.query<
EntityState<TextualInversionModelConfigEntity>,
void
>({
query: () => ({ url: 'models/', params: { model_type: 'embedding' } }),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ id: 'TextualInversionModel', type: LIST_TAG },
];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'TextualInversionModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (
response: { models: TextualInversionModelConfig[] },
meta,
arg
) => {
const entities = createModelEntities<TextualInversionModelConfigEntity>(
response.models
);
return textualInversionModelsAdapter.setAll(
textualInversionModelsAdapter.getInitialState(),
entities
); );
}, },
}), }),
}), }),
}); });
export const { useListModelsQuery } = modelsApi; export const {
useGetMainModelsQuery,
useGetControlNetModelsQuery,
useGetLoRAModelsQuery,
useGetTextualInversionModelsQuery,
useGetVaeModelsQuery,
} = modelsApi;

View File

@ -2690,6 +2690,19 @@ export type components = {
model_format: components["schemas"]["LoRAModelFormat"]; model_format: components["schemas"]["LoRAModelFormat"];
error?: components["schemas"]["ModelError"]; error?: components["schemas"]["ModelError"];
}; };
/**
* LoRAModelField
* @description LoRA model field
*/
LoRAModelField: {
/**
* Model Name
* @description Name of the LoRA model
*/
model_name: string;
/** @description Base model */
base_model: components["schemas"]["BaseModelType"];
};
/** /**
* LoRAModelFormat * LoRAModelFormat
* @description An enumeration. * @description An enumeration.
@ -2766,10 +2779,10 @@ export type components = {
*/ */
type?: "lora_loader"; type?: "lora_loader";
/** /**
* Lora Name * Lora
* @description Lora model name * @description Lora model name
*/ */
lora_name: string; lora?: components["schemas"]["LoRAModelField"];
/** /**
* Weight * Weight
* @description With what weight to apply lora * @description With what weight to apply lora
@ -3115,7 +3128,7 @@ export type components = {
/** ModelsList */ /** ModelsList */
ModelsList: { ModelsList: {
/** Models */ /** Models */
models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[]; models: (components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[];
}; };
/** /**
* MultiplyInvocation * MultiplyInvocation
@ -4448,18 +4461,18 @@ export type components = {
*/ */
image?: components["schemas"]["ImageField"]; image?: components["schemas"]["ImageField"];
}; };
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/** /**
* StableDiffusion1ModelFormat * StableDiffusion1ModelFormat
* @description An enumeration. * @description An enumeration.
* @enum {string} * @enum {string}
*/ */
StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
}; };
responses: never; responses: never;
parameters: never; parameters: never;

Some files were not shown because too many files have changed in this diff Show More