Merge branch 'main' into lstein/remove-hardcoded-cuda-device

This commit is contained in:
Lincoln Stein 2023-07-05 13:38:33 -04:00 committed by GitHub
commit 9f9ce08e44
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
133 changed files with 3999 additions and 4024 deletions

2
.gitignore vendored
View File

@ -201,8 +201,6 @@ checkpoints
# If it's a Mac
.DS_Store
invokeai/frontend/web/dist/*
# Let the frontend manage its own gitignore
!invokeai/frontend/web/*

View File

@ -2,17 +2,17 @@
from typing import Literal, Optional, Union
from fastapi import Query
from fastapi import Query, Body
from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management import AddModelResult
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
models_router = APIRouter(prefix="/v1/models", tags=["models"])
class VaeRepo(BaseModel):
repo_id: str = Field(description="The repo ID to use for this VAE")
path: Optional[str] = Field(description="The path to the VAE")
@ -51,9 +51,12 @@ class CreateModelResponse(BaseModel):
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
status: str = Field(description="The status of the API response")
class ImportModelRequest(BaseModel):
name: str = Field(description="A model path, repo_id or URL to import")
prediction_type: Optional[Literal['epsilon','v_prediction','sample']] = Field(description='Prediction type for SDv2 checkpoint files')
class ImportModelResponse(BaseModel):
name: str = Field(description="The name of the imported model")
# base_model: str = Field(description="The base model")
# model_type: str = Field(description="The model type")
info: AddModelResult = Field(description="The model info")
status: str = Field(description="The status of the API response")
class ConversionRequest(BaseModel):
name: str = Field(description="The name of the new model")
@ -86,7 +89,6 @@ async def list_models(
models = parse_obj_as(ModelsList, { "models": models_raw })
return models
@models_router.post(
"/",
operation_id="update_model",
@ -109,27 +111,38 @@ async def update_model(
return model_response
@models_router.post(
"/",
"/import",
operation_id="import_model",
responses={200: {"status": "success"}},
responses= {
201: {"description" : "The model imported successfully"},
404: {"description" : "The model could not be found"},
},
status_code=201,
response_model=ImportModelResponse
)
async def import_model(
model_request: ImportModelRequest
) -> None:
""" Add Model """
items_to_import = set([model_request.name])
name: str = Query(description="A model path, repo_id or URL to import"),
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = Query(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
) -> ImportModelResponse:
""" Add a model using its local path, repo_id, or remote URL """
items_to_import = {name}
prediction_types = { x.value: x for x in SchedulerPredictionType }
logger = ApiDependencies.invoker.services.logger
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
items_to_import = items_to_import,
prediction_type_helper = lambda x: prediction_types.get(model_request.prediction_type)
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
)
if len(installed_models) > 0:
logger.info(f'Successfully imported {model_request.name}')
if info := installed_models.get(name):
logger.info(f'Successfully imported {name}, got {info}')
return ImportModelResponse(
name = name,
info = info,
status = "success",
)
else:
logger.error(f'Model {model_request.name} not imported')
raise HTTPException(status_code=500, detail=f'Model {model_request.name} not imported')
logger.error(f'Model {name} not imported')
raise HTTPException(status_code=404, detail=f'Model {name} not found')
@models_router.delete(
"/{model_name}",

View File

@ -4,9 +4,10 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from inspect import signature
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict, TYPE_CHECKING
from typing import (TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args,
get_type_hints)
from pydantic import BaseModel, Field
from pydantic import BaseConfig, BaseModel, Field
if TYPE_CHECKING:
from ..services.invocation_services import InvocationServices
@ -65,8 +66,13 @@ class BaseInvocation(ABC, BaseModel):
@classmethod
def get_invocations_map(cls):
# Get the type strings out of the literals and into a dictionary
return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseInvocation.get_all_subclasses()))
return dict(
map(
lambda t: (get_args(get_type_hints(t)["type"])[0], t),
BaseInvocation.get_all_subclasses(),
)
)
@classmethod
def get_output_type(cls):
return signature(cls.invoke).return_annotation
@ -75,11 +81,11 @@ class BaseInvocation(ABC, BaseModel):
def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
"""Invoke with provided context and return outputs."""
pass
#fmt: off
# fmt: off
id: str = Field(description="The id of this node. Must be unique among all nodes.")
is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
#fmt: on
# fmt: on
# TODO: figure out a better way to provide these hints
@ -98,16 +104,19 @@ class UIConfig(TypedDict, total=False):
"model",
"control",
"image_collection",
"vae_model",
"lora_model",
],
]
tags: List[str]
title: str
class CustomisedSchemaExtra(TypedDict):
ui: UIConfig
class InvocationConfig(BaseModel.Config):
class InvocationConfig(BaseConfig):
"""Customizes pydantic's BaseModel.Config class for use by Invocations.
Provide `schema_extra` a `ui` dict to add hints for generated UIs.

View File

@ -1,27 +1,28 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel, Field
from contextlib import ExitStack
import re
from contextlib import ExitStack
from typing import List, Literal, Optional, Union
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from .model import ClipField
import torch
from compel import Compel
from compel.prompt_parser import (Blend, Conjunction,
CrossAttentionControlSubstitute,
FlattenedPrompt, Fragment)
from pydantic import BaseModel, Field
from ...backend.util.devices import torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.model_management.models import ModelNotFoundException
from ...backend.model_management import BaseModelType, ModelType, SubModelType
from ...backend.model_management.lora import ModelPatcher
from compel import Compel
from compel.prompt_parser import (
Blend,
CrossAttentionControlSubstitute,
FlattenedPrompt,
Fragment, Conjunction,
)
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.util.devices import torch_dtype
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
from .model import ClipField
class ConditioningField(BaseModel):
conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")
conditioning_name: Optional[str] = Field(
default=None, description="The name of conditioning data")
class Config:
schema_extra = {"required": ["conditioning_name"]}
@ -51,83 +52,92 @@ class CompelInvocation(BaseInvocation):
"title": "Prompt (Compel)",
"tags": ["prompt", "compel"],
"type_hints": {
"model": "model"
"model": "model"
}
},
}
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
tokenizer_info = context.services.model_manager.get_model(
**self.clip.tokenizer.dict(),
)
text_encoder_info = context.services.model_manager.get_model(
**self.clip.text_encoder.dict(),
)
with tokenizer_info as orig_tokenizer,\
text_encoder_info as text_encoder:
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
def _lora_loader():
for lora in self.clip.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
del lora_info
return
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
name = trigger[1:-1]
try:
ti_list.append(
context.services.model_manager.get_model(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
).context.model
)
except Exception:
#print(e)
#import traceback
#print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found")
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
with ModelPatcher.apply_lora_text_encoder(text_encoder, loras),\
ModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager):
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=True, # TODO:
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
name = trigger[1:-1]
try:
ti_list.append(
context.services.model_manager.get_model(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
).context.model
)
conjunction = Compel.parse_prompt_string(self.prompt)
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
except ModelNotFoundException:
# print(e)
#import traceback
#print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found")
if context.services.configuration.log_tokenization:
log_tokenization_for_prompt_object(prompt, tokenizer)
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
text_encoder_info as text_encoder:
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
# TODO: long prompt support
#if not self.truncate_long_prompts:
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
cross_attention_control_args=options.get("cross_attention_control", None),
)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
# TODO: hacky but works ;D maybe rename latents somehow?
context.services.latents.save(conditioning_name, (c, ec))
return CompelOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=True, # TODO:
)
conjunction = Compel.parse_prompt_string(self.prompt)
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
if context.services.configuration.log_tokenization:
log_tokenization_for_prompt_object(prompt, tokenizer)
c, options = compel.build_conditioning_tensor_for_prompt_object(
prompt)
# TODO: long prompt support
# if not self.truncate_long_prompts:
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(
tokenizer, conjunction),
cross_attention_control_args=options.get(
"cross_attention_control", None),)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
# TODO: hacky but works ;D maybe rename latents somehow?
context.services.latents.save(conditioning_name, (c, ec))
return CompelOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
def get_max_token_count(
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False
) -> int:
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction],
truncate_if_too_long=False) -> int:
if type(prompt) is Blend:
blend: Blend = prompt
return max(
@ -146,13 +156,13 @@ def get_max_token_count(
)
else:
return len(
get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)
)
get_tokens_for_prompt_object(
tokenizer, prompt, truncate_if_too_long))
def get_tokens_for_prompt_object(
tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
) -> [str]:
) -> List[str]:
if type(parsed_prompt) is Blend:
raise ValueError(
"Blend is not supported here - you need to get tokens for each of its .children"
@ -181,7 +191,7 @@ def log_tokenization_for_conjunction(
):
display_label_prefix = display_label_prefix or ""
for i, p in enumerate(c.prompts):
if len(c.prompts)>1:
if len(c.prompts) > 1:
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
else:
this_display_label_prefix = display_label_prefix
@ -236,7 +246,8 @@ def log_tokenization_for_prompt_object(
)
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
def log_tokenization_for_text(
text, tokenizer, display_label=None, truncate_if_too_long=False):
"""shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '

View File

@ -4,18 +4,17 @@ from contextlib import ExitStack
from typing import List, Literal, Optional, Union
import einops
from pydantic import BaseModel, Field, validator
import torch
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import BaseModel, Field, validator
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ...backend.image_util.seamless import configure_model_padding
from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline,
@ -24,7 +23,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import torch_dtype
from ...backend.model_management.lora import ModelPatcher
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
from .compel import ConditioningField
@ -32,14 +31,17 @@ from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField
class LatentsField(BaseModel):
"""A latents field used for passing latents between invocations"""
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
latents_name: Optional[str] = Field(
default=None, description="The name of the latents")
class Config:
schema_extra = {"required": ["latents_name"]}
class LatentsOutput(BaseInvocationOutput):
"""Base class for invocations that output latents"""
#fmt: off
@ -53,11 +55,11 @@ class LatentsOutput(BaseInvocationOutput):
def build_latents_output(latents_name: str, latents: torch.Tensor):
return LatentsOutput(
latents=LatentsField(latents_name=latents_name),
width=latents.size()[3] * 8,
height=latents.size()[2] * 8,
)
return LatentsOutput(
latents=LatentsField(latents_name=latents_name),
width=latents.size()[3] * 8,
height=latents.size()[2] * 8,
)
SAMPLER_NAME_VALUES = Literal[
@ -70,16 +72,19 @@ def get_scheduler(
scheduler_info: ModelInfo,
scheduler_name: str,
) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict())
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
scheduler_name, SCHEDULER_MAP['ddim'])
orig_scheduler_info = context.services.model_manager.get_model(
**scheduler_info.dict())
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"]
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
scheduler_config = {**scheduler_config, **
scheduler_extra_config, "_backup": scheduler_config}
scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py
if not hasattr(scheduler, 'uses_inpainting_model'):
scheduler.uses_inpainting_model = lambda: False
@ -124,18 +129,18 @@ class TextToLatentsInvocation(BaseInvocation):
"ui": {
"tags": ["latents"],
"type_hints": {
"model": "model",
"control": "control",
# "cfg_scale": "float",
"cfg_scale": "number"
"model": "model",
"control": "control",
# "cfg_scale": "float",
"cfg_scale": "number"
}
},
}
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
) -> None:
self, context: InvocationContext, source_node_id: str,
intermediate_state: PipelineIntermediateState) -> None:
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
@ -143,9 +148,12 @@ class TextToLatentsInvocation(BaseInvocation):
source_node_id=source_node_id,
)
def get_conditioning_data(self, context: InvocationContext, scheduler) -> ConditioningData:
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
def get_conditioning_data(
self, context: InvocationContext, scheduler) -> ConditioningData:
c, extra_conditioning_info = context.services.latents.get(
self.positive_conditioning.conditioning_name)
uc, _ = context.services.latents.get(
self.negative_conditioning.conditioning_name)
conditioning_data = ConditioningData(
unconditioned_embeddings=uc,
@ -153,10 +161,10 @@ class TextToLatentsInvocation(BaseInvocation):
guidance_scale=self.cfg_scale,
extra=extra_conditioning_info,
postprocessing_settings=PostprocessingSettings(
threshold=0.0,#threshold,
warmup=0.2,#warmup,
h_symmetry_time_pct=None,#h_symmetry_time_pct,
v_symmetry_time_pct=None#v_symmetry_time_pct,
threshold=0.0, # threshold,
warmup=0.2, # warmup,
h_symmetry_time_pct=None, # h_symmetry_time_pct,
v_symmetry_time_pct=None # v_symmetry_time_pct,
),
)
@ -164,31 +172,32 @@ class TextToLatentsInvocation(BaseInvocation):
scheduler,
# for ddim scheduler
eta=0.0, #ddim_eta
eta=0.0, # ddim_eta
# for ancestral and sde schedulers
generator=torch.Generator(device=uc.device).manual_seed(0),
)
return conditioning_data
def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
def create_pipeline(
self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
# TODO:
#configure_model_padding(
# configure_model_padding(
# unet,
# self.seamless,
# self.seamless_axes,
#)
# )
class FakeVae:
class FakeVaeConfig:
def __init__(self):
self.block_out_channels = [0]
def __init__(self):
self.config = FakeVae.FakeVaeConfig()
return StableDiffusionGeneratorPipeline(
vae=FakeVae(), # TODO: oh...
vae=FakeVae(), # TODO: oh...
text_encoder=None,
tokenizer=None,
unet=unet,
@ -198,11 +207,12 @@ class TextToLatentsInvocation(BaseInvocation):
requires_safety_checker=False,
precision="float16" if unet.dtype == torch.float16 else "float32",
)
def prep_control_data(
self,
context: InvocationContext,
model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device
# really only need model for dtype and device
model: StableDiffusionGeneratorPipeline,
control_input: List[ControlField],
latents_shape: List[int],
do_classifier_free_guidance: bool = True,
@ -238,15 +248,17 @@ class TextToLatentsInvocation(BaseInvocation):
print("Using HF model subfolders")
print(" control_name: ", control_name)
print(" control_subfolder: ", control_subfolder)
control_model = ControlNetModel.from_pretrained(control_name,
subfolder=control_subfolder,
torch_dtype=model.unet.dtype).to(model.device)
control_model = ControlNetModel.from_pretrained(
control_name, subfolder=control_subfolder,
torch_dtype=model.unet.dtype).to(
model.device)
else:
control_model = ControlNetModel.from_pretrained(control_info.control_model,
torch_dtype=model.unet.dtype).to(model.device)
control_model = ControlNetModel.from_pretrained(
control_info.control_model, torch_dtype=model.unet.dtype).to(model.device)
control_models.append(control_model)
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_name)
input_image = context.services.images.get_pil_image(
control_image_field.image_name)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
@ -263,41 +275,50 @@ class TextToLatentsInvocation(BaseInvocation):
dtype=control_model.dtype,
control_mode=control_info.control_mode,
)
control_item = ControlNetData(model=control_model,
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,
)
control_item = ControlNetData(
model=control_model, image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,)
control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state)
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
with unet_info as unet:
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict())
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet:
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler)
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
control_data = self.prep_control_data(
model=pipeline, context=context, control_input=self.control,
latents_shape=noise.shape,
@ -305,16 +326,15 @@ class TextToLatentsInvocation(BaseInvocation):
do_classifier_free_guidance=True,
)
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
# TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback,
)
# TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
@ -323,14 +343,18 @@ class TextToLatentsInvocation(BaseInvocation):
context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents)
class LatentsToLatentsInvocation(TextToLatentsInvocation):
"""Generates latents using latents as base image."""
type: Literal["l2l"] = "l2l"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use")
latents: Optional[LatentsField] = Field(
description="The latents to use as a base image")
strength: float = Field(
default=0.7, ge=0, le=1,
description="The strength of the latents to use")
# Schema customisation
class Config(InvocationConfig):
@ -345,22 +369,31 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
},
}
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state)
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(),
)
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
del lora_info
return
with unet_info as unet:
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict())
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet:
scheduler = get_scheduler(
context=context,
@ -370,7 +403,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler)
control_data = self.prep_control_data(
model=pipeline, context=context, control_input=self.control,
latents_shape=noise.shape,
@ -380,8 +413,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
# TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
latent, device=unet.device, dtype=latent.dtype
)
latent, device=unet.device, dtype=latent.dtype)
timesteps, _ = pipeline.get_img2img_timesteps(
self.steps,
@ -389,18 +421,15 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
device=unet.device,
)
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=initial_latents,
timesteps=timesteps,
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback
)
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=initial_latents,
timesteps=timesteps,
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
@ -417,9 +446,12 @@ class LatentsToImageInvocation(BaseInvocation):
type: Literal["l2i"] = "l2i"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
latents: Optional[LatentsField] = Field(
description="The latents to generate an image from")
vae: VaeField = Field(default=None, description="Vae submodel")
tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
tiled: bool = Field(
default=False,
description="Decode latents by overlaping tiles(less memory consumption)")
# Schema customisation
class Config(InvocationConfig):
@ -450,7 +482,7 @@ class LatentsToImageInvocation(BaseInvocation):
# copied from diffusers pipeline
latents = latents / vae.config.scaling_factor
image = vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1) # denormalize
image = (image / 2 + 0.5).clamp(0, 1) # denormalize
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@ -473,9 +505,9 @@ class LatentsToImageInvocation(BaseInvocation):
height=image_dto.height,
)
LATENTS_INTERPOLATION_MODE = Literal[
"nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
]
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear",
"bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
class ResizeLatentsInvocation(BaseInvocation):
@ -484,21 +516,25 @@ class ResizeLatentsInvocation(BaseInvocation):
type: Literal["lresize"] = "lresize"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to resize")
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
latents: Optional[LatentsField] = Field(
description="The latents to resize")
width: int = Field(
ge=64, multiple_of=8, description="The width to resize to (px)")
height: int = Field(
ge=64, multiple_of=8, description="The height to resize to (px)")
mode: LATENTS_INTERPOLATION_MODE = Field(
default="bilinear", description="The interpolation mode")
antialias: bool = Field(
default=False,
description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
resized_latents = torch.nn.functional.interpolate(
latents,
size=(self.height // 8, self.width // 8),
mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
latents, size=(self.height // 8, self.width // 8),
mode=self.mode, antialias=self.antialias
if self.mode in ["bilinear", "bicubic"] else False,)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
@ -515,21 +551,24 @@ class ScaleLatentsInvocation(BaseInvocation):
type: Literal["lscale"] = "lscale"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to scale")
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
latents: Optional[LatentsField] = Field(
description="The latents to scale")
scale_factor: float = Field(
gt=0, description="The factor by which to scale the latents")
mode: LATENTS_INTERPOLATION_MODE = Field(
default="bilinear", description="The interpolation mode")
antialias: bool = Field(
default=False,
description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
# resizing
resized_latents = torch.nn.functional.interpolate(
latents,
scale_factor=self.scale_factor,
mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
latents, scale_factor=self.scale_factor, mode=self.mode,
antialias=self.antialias
if self.mode in ["bilinear", "bicubic"] else False,)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
@ -548,7 +587,9 @@ class ImageToLatentsInvocation(BaseInvocation):
# Inputs
image: Union[ImageField, None] = Field(description="The image to encode")
vae: VaeField = Field(default=None, description="Vae submodel")
tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)")
tiled: bool = Field(
default=False,
description="Encode latents by overlaping tiles(less memory consumption)")
# Schema customisation
class Config(InvocationConfig):

View File

@ -1,31 +1,38 @@
from typing import Literal, Optional, Union, List
from pydantic import BaseModel, Field
import copy
from typing import List, Literal, Optional, Union
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from pydantic import BaseModel, Field
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.model_management import BaseModelType, ModelType, SubModelType
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
class ModelInfo(BaseModel):
model_name: str = Field(description="Info to load submodel")
base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Info to load submodel")
submodel: Optional[SubModelType] = Field(description="Info to load submodel")
submodel: Optional[SubModelType] = Field(
default=None, description="Info to load submodel"
)
class LoraInfo(ModelInfo):
weight: float = Field(description="Lora's weight which to use when apply to model")
class UNetField(BaseModel):
unet: ModelInfo = Field(description="Info to load unet submodel")
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
class ClipField(BaseModel):
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
class VaeField(BaseModel):
# TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel")
@ -34,43 +41,48 @@ class VaeField(BaseModel):
class ModelLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
#fmt: off
# fmt: off
type: Literal["model_loader_output"] = "model_loader_output"
unet: UNetField = Field(default=None, description="UNet submodel")
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
vae: VaeField = Field(default=None, description="Vae submodel")
#fmt: on
# fmt: on
class PipelineModelField(BaseModel):
"""Pipeline model field"""
class MainModelField(BaseModel):
"""Main model field"""
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
class PipelineModelLoaderInvocation(BaseInvocation):
"""Loads a pipeline model, outputting its submodels."""
class LoRAModelField(BaseModel):
"""LoRA model field"""
type: Literal["pipeline_model_loader"] = "pipeline_model_loader"
model_name: str = Field(description="Name of the LoRA model")
base_model: BaseModelType = Field(description="Base model")
model: PipelineModelField = Field(description="The model to load")
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
type: Literal["main_model_loader"] = "main_model_loader"
model: MainModelField = Field(description="The model to load")
# TODO: precision?
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Model Loader",
"tags": ["model", "loader"],
"type_hints": {
"model": "model"
}
"type_hints": {"model": "model"},
},
}
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Main
@ -112,7 +124,6 @@ class PipelineModelLoaderInvocation(BaseInvocation):
)
"""
return ModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
@ -151,47 +162,66 @@ class PipelineModelLoaderInvocation(BaseInvocation):
model_type=model_type,
submodel=SubModelType.Vae,
),
)
),
)
class LoraLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
#fmt: off
# fmt: off
type: Literal["lora_loader_output"] = "lora_loader_output"
unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
#fmt: on
# fmt: on
class LoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
type: Literal["lora_loader"] = "lora_loader"
lora_name: str = Field(description="Lora model name")
lora: Union[LoRAModelField, None] = Field(
default=None, description="Lora model name"
)
weight: float = Field(default=0.75, description="With what weight to apply lora")
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Lora Loader",
"tags": ["lora", "loader"],
"type_hints": {"lora": "lora_model"},
},
}
# TODO: ui rewrite
base_model = BaseModelType.StableDiffusion1
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
base_model = self.lora.base_model
lora_name = self.lora.model_name
if not context.services.model_manager.model_exists(
base_model=base_model,
model_name=self.lora_name,
model_name=lora_name,
model_type=ModelType.Lora,
):
raise Exception(f"Unkown lora name: {self.lora_name}!")
raise Exception(f"Unkown lora name: {lora_name}!")
if self.unet is not None and any(lora.model_name == self.lora_name for lora in self.unet.loras):
raise Exception(f"Lora \"{self.lora_name}\" already applied to unet")
if self.unet is not None and any(
lora.model_name == lora_name for lora in self.unet.loras
):
raise Exception(f'Lora "{lora_name}" already applied to unet')
if self.clip is not None and any(lora.model_name == self.lora_name for lora in self.clip.loras):
raise Exception(f"Lora \"{self.lora_name}\" already applied to clip")
if self.clip is not None and any(
lora.model_name == lora_name for lora in self.clip.loras
):
raise Exception(f'Lora "{lora_name}" already applied to clip')
output = LoraLoaderOutput()
@ -200,7 +230,7 @@ class LoraLoaderInvocation(BaseInvocation):
output.unet.loras.append(
LoraInfo(
base_model=base_model,
model_name=self.lora_name,
model_name=lora_name,
model_type=ModelType.Lora,
submodel=None,
weight=self.weight,
@ -212,7 +242,7 @@ class LoraLoaderInvocation(BaseInvocation):
output.clip.loras.append(
LoraInfo(
base_model=base_model,
model_name=self.lora_name,
model_name=lora_name,
model_type=ModelType.Lora,
submodel=None,
weight=self.weight,
@ -221,3 +251,58 @@ class LoraLoaderInvocation(BaseInvocation):
return output
class VAEModelField(BaseModel):
"""Vae model field"""
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
class VaeLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
# fmt: off
type: Literal["vae_loader_output"] = "vae_loader_output"
vae: VaeField = Field(default=None, description="Vae model")
# fmt: on
class VaeLoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""
type: Literal["vae_loader"] = "vae_loader"
vae_model: VAEModelField = Field(description="The VAE to load")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "VAE Loader",
"tags": ["vae", "loader"],
"type_hints": {"vae_model": "vae_model"},
},
}
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
base_model = self.vae_model.base_model
model_name = self.vae_model.model_name
model_type = ModelType.Vae
if not context.services.model_manager.model_exists(
base_model=base_model,
model_name=model_name,
model_type=model_type,
):
raise Exception(f"Unkown vae name: {model_name}!")
return VaeLoaderOutput(
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
)
)

View File

@ -228,10 +228,10 @@ class InvokeAISettings(BaseSettings):
upcase_environ = dict()
for key,value in os.environ.items():
upcase_environ[key.upper()] = value
fields = cls.__fields__
cls.argparse_groups = {}
for name, field in fields.items():
if name not in cls._excluded():
current_default = field.default
@ -348,7 +348,7 @@ setting environment variables INVOKEAI_<setting>.
'''
singleton_config: ClassVar[InvokeAIAppConfig] = None
singleton_init: ClassVar[Dict] = None
#fmt: off
type: Literal["InvokeAI"] = "InvokeAI"
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
@ -367,7 +367,8 @@ setting environment variables INVOKEAI_<setting>.
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
max_loaded_models : int = Field(default=3, gt=0, description="Maximum number of models to keep in memory for rapid switching", category='Memory/Performance')
max_loaded_models : int = Field(default=3, gt=0, description="(DEPRECATED: use max_cache_size) Maximum number of models to keep in memory for rapid switching", category='Memory/Performance')
max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance')
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
@ -385,9 +386,9 @@ setting environment variables INVOKEAI_<setting>.
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
log_format : Literal[tuple(['plain','color','syslog','legacy'])] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
@ -396,7 +397,7 @@ setting environment variables INVOKEAI_<setting>.
def parse_args(self, argv: List[str]=None, conf: DictConfig = None, clobber=False):
'''
Update settings with contents of init file, environment, and
Update settings with contents of init file, environment, and
command-line settings.
:param conf: alternate Omegaconf dictionary object
:param argv: aternate sys.argv list
@ -411,7 +412,7 @@ setting environment variables INVOKEAI_<setting>.
except:
pass
InvokeAISettings.initconf = conf
# parse args again in order to pick up settings in configuration file
super().parse_args(argv)
@ -431,7 +432,7 @@ setting environment variables INVOKEAI_<setting>.
cls.singleton_config = cls(**kwargs)
cls.singleton_init = kwargs
return cls.singleton_config
@property
def root_path(self)->Path:
'''

View File

@ -33,13 +33,13 @@ class ModelManagerServiceBase(ABC):
logger: types.ModuleType,
):
"""
Initialize with the path to the models.yaml config file.
Initialize with the path to the models.yaml config file.
Optional parameters are the torch device type, precision, max_models,
and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision.
"""
pass
@abstractmethod
def get_model(
self,
@ -50,8 +50,8 @@ class ModelManagerServiceBase(ABC):
node: Optional[BaseInvocation] = None,
context: Optional[InvocationContext] = None,
) -> ModelInfo:
"""Retrieve the indicated model with name and type.
submodel can be used to get a part (such as the vae)
"""Retrieve the indicated model with name and type.
submodel can be used to get a part (such as the vae)
of a diffusers pipeline."""
pass
@ -115,8 +115,8 @@ class ModelManagerServiceBase(ABC):
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
pass
@ -129,12 +129,35 @@ class ModelManagerServiceBase(ABC):
model_type: ModelType,
):
"""
Delete the named model from configuration. If delete_files is true,
then the underlying weight file or diffusers directory will be deleted
Delete the named model from configuration. If delete_files is true,
then the underlying weight file or diffusers directory will be deleted
as well. Call commit() to write to disk.
"""
pass
@abstractmethod
def heuristic_import(self,
items_to_import: Set[str],
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
)->Dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
The prediction type helper is necessary to distinguish between
models based on Stable Diffusion 2 Base (requiring
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
(requiring SchedulerPredictionType.VPrediction). It is
generally impossible to do this programmatically, so the
prediction_type_helper usually asks the user to choose.
The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.
'''
pass
@abstractmethod
def commit(self, conf_file: Path = None) -> None:
"""
@ -153,7 +176,7 @@ class ModelManagerService(ModelManagerServiceBase):
logger: types.ModuleType,
):
"""
Initialize with the path to the models.yaml config file.
Initialize with the path to the models.yaml config file.
Optional parameters are the torch device type, precision, max_models,
and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision.
@ -183,6 +206,8 @@ class ModelManagerService(ModelManagerServiceBase):
if hasattr(config,'max_cache_size') \
else config.max_loaded_models * 2.5
logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")
sequential_offload = config.sequential_guidance
self.mgr = ModelManager(
@ -238,7 +263,7 @@ class ModelManagerService(ModelManagerServiceBase):
submodel=submodel,
model_info=model_info
)
return model_info
def model_exists(
@ -291,8 +316,8 @@ class ModelManagerService(ModelManagerServiceBase):
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
@ -305,8 +330,8 @@ class ModelManagerService(ModelManagerServiceBase):
model_type: ModelType,
):
"""
Delete the named model from configuration. If delete_files is true,
then the underlying weight file or diffusers directory will be deleted
Delete the named model from configuration. If delete_files is true,
then the underlying weight file or diffusers directory will be deleted
as well. Call commit() to write to disk.
"""
self.mgr.del_model(model_name, base_model, model_type)
@ -360,4 +385,25 @@ class ModelManagerService(ModelManagerServiceBase):
@property
def logger(self):
return self.mgr.logger
def heuristic_import(self,
items_to_import: Set[str],
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
)->Dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
The prediction type helper is necessary to distinguish between
models based on Stable Diffusion 2 Base (requiring
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
(requiring SchedulerPredictionType.VPrediction). It is
generally impossible to do this programmatically, so the
prediction_type_helper usually asks the user to choose.
The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.
'''
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)

View File

@ -76,6 +76,10 @@ class MigrateTo3(object):
Create a unique name for a model for use within models.yaml.
'''
done = False
# some model names have slashes in them, which really screws things up
name = name.replace('/','_')
key = ModelManager.create_key(name,info.base_type,info.model_type)
unique_name = key
counter = 1
@ -219,11 +223,11 @@ class MigrateTo3(object):
repo_id = 'openai/clip-vit-large-patch14'
self._migrate_pretrained(CLIPTokenizer,
repo_id= repo_id,
dest= target_dir / 'clip-vit-large-patch14' / 'tokenizer',
dest= target_dir / 'clip-vit-large-patch14',
**kwargs)
self._migrate_pretrained(CLIPTextModel,
repo_id = repo_id,
dest = target_dir / 'clip-vit-large-patch14' / 'text_encoder',
dest = target_dir / 'clip-vit-large-patch14',
**kwargs)
# sd-2

View File

@ -18,7 +18,7 @@ from tqdm import tqdm
import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
from invokeai.backend.util import download_with_resume
from ..util.logging import InvokeAILogger
@ -166,17 +166,22 @@ class ModelInstall(object):
# add requested models
for path in selections.install_models:
logger.info(f'Installing {path} [{job}/{jobs}]')
self.heuristic_install(path)
self.heuristic_import(path)
job += 1
self.mgr.commit()
def heuristic_install(self,
def heuristic_import(self,
model_path_id_or_url: Union[str,Path],
models_installed: Set[Path]=None)->Set[Path]:
models_installed: Set[Path]=None)->Dict[str, AddModelResult]:
'''
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
:param models_installed: Set of installed models, used for recursive invocation
Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
'''
if not models_installed:
models_installed = set()
models_installed = dict()
# A little hack to allow nested routines to retrieve info on the requested ID
self.current_id = model_path_id_or_url
@ -185,24 +190,24 @@ class ModelInstall(object):
try:
# checkpoint file, or similar
if path.is_file():
models_installed.add(self._install_path(path))
models_installed.update(self._install_path(path))
# folders style or similar
elif path.is_dir() and any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
models_installed.add(self._install_path(path))
models_installed.update(self._install_path(path))
# recursive scan
elif path.is_dir():
for child in path.iterdir():
self.heuristic_install(child, models_installed=models_installed)
self.heuristic_import(child, models_installed=models_installed)
# huggingface repo
elif len(str(path).split('/')) == 2:
models_installed.add(self._install_repo(str(path)))
models_installed.update(self._install_repo(str(path)))
# a URL
elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")):
models_installed.add(self._install_url(model_path_id_or_url))
models_installed.update(self._install_url(model_path_id_or_url))
else:
logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
@ -214,24 +219,25 @@ class ModelInstall(object):
# install a model from a local path. The optional info parameter is there to prevent
# the model from being probed twice in the event that it has already been probed.
def _install_path(self, path: Path, info: ModelProbeInfo=None)->Path:
def _install_path(self, path: Path, info: ModelProbeInfo=None)->Dict[str, AddModelResult]:
try:
# logger.debug(f'Probing {path}')
model_result = None
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
model_name = path.stem if info.format=='checkpoint' else path.name
model_name = path.stem if path.is_file() else path.name
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
raise ValueError(f'A model named "{model_name}" is already installed.')
attributes = self._make_attributes(path,info)
self.mgr.add_model(model_name = model_name,
base_model = info.base_type,
model_type = info.model_type,
model_attributes = attributes,
)
model_result = self.mgr.add_model(model_name = model_name,
base_model = info.base_type,
model_type = info.model_type,
model_attributes = attributes,
)
except Exception as e:
logger.warning(f'{str(e)} Skipping registration.')
return path
return {}
return {str(path): model_result}
def _install_url(self, url: str)->Path:
def _install_url(self, url: str)->dict:
# copy to a staging area, probe, import and delete
with TemporaryDirectory(dir=self.config.models_path) as staging:
location = download_with_resume(url,Path(staging))
@ -244,7 +250,7 @@ class ModelInstall(object):
# staged version will be garbage-collected at this time
return self._install_path(Path(models_path), info)
def _install_repo(self, repo_id: str)->Path:
def _install_repo(self, repo_id: str)->dict:
hinfo = HfApi().model_info(repo_id)
# we try to figure out how to download this most economically

View File

@ -1,7 +1,7 @@
"""
Initialization file for invokeai.backend.model_management
"""
from .model_manager import ModelManager, ModelInfo
from .model_manager import ModelManager, ModelInfo, AddModelResult
from .model_cache import ModelCache
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType

View File

@ -29,7 +29,7 @@ import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from .model_manager import ModelManager
from .model_cache import ModelCache
from picklescan.scanner import scan_file_path
from .models import BaseModelType, ModelVariantType
try:
@ -1014,7 +1014,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint = load_file(checkpoint_path)
else:
if scan_needed:
ModelCache.scan_model(checkpoint_path, checkpoint_path)
# scan model
scan_result = scan_file_path(checkpoint_path)
if scan_result.infected_files != 0:
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
checkpoint = torch.load(checkpoint_path)
# sometimes there is a state_dict key and sometimes not

View File

@ -1,18 +1,17 @@
from __future__ import annotations
import copy
from pathlib import Path
from contextlib import contextmanager
from typing import Optional, Dict, Tuple, Any
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
import torch
from compel.embeddings_provider import BaseTextualInversionManager
from diffusers.models import UNet2DConditionModel
from safetensors.torch import load_file
from torch.utils.hooks import RemovableHandle
from diffusers.models import UNet2DConditionModel
from transformers import CLIPTextModel
from compel.embeddings_provider import BaseTextualInversionManager
class LoRALayerBase:
#rank: Optional[int]
@ -539,9 +538,10 @@ class ModelPatcher:
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
# enable autocast to calc fp16 loras on cpu
with torch.autocast(device_type="cpu"):
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
layer_weight = layer.get_weight() * lora_weight * layer_scale
#with torch.autocast(device_type="cpu"):
layer.to(dtype=torch.float32)
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
layer_weight = layer.get_weight() * lora_weight * layer_scale
if module.weight.shape != layer_weight.shape:
# TODO: debug on lycoris
@ -655,6 +655,9 @@ class TextualInversionModel:
else:
result.embedding = next(iter(state_dict.values()))
if len(result.embedding.shape) == 1:
result.embedding = result.embedding.unsqueeze(0)
if not isinstance(result.embedding, torch.Tensor):
raise ValueError(f"Invalid embeddings file: {file_path.name}")

View File

@ -233,14 +233,14 @@ import hashlib
import textwrap
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, List, Tuple, Union, Set, Callable, types
from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types
from shutil import rmtree
import torch
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from pydantic import BaseModel
from pydantic import BaseModel, Field
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
@ -249,7 +249,7 @@ from .model_cache import ModelCache, ModelLocker
from .models import (
BaseModelType, ModelType, SubModelType,
ModelError, SchedulerPredictionType, MODEL_CLASSES,
ModelConfigBase,
ModelConfigBase, ModelNotFoundException,
)
# We are only starting to number the config file with release 3.
@ -278,8 +278,13 @@ class InvalidModelError(Exception):
"Raised when an invalid model is requested"
pass
MAX_CACHE_SIZE = 6.0 # GB
class AddModelResult(BaseModel):
name: str = Field(description="The name of the model after import")
model_type: ModelType = Field(description="The type of model")
base_model: BaseModelType = Field(description="The base model")
config: ModelConfigBase = Field(description="The configuration of the model")
MAX_CACHE_SIZE = 6.0 # GB
class ConfigMeta(BaseModel):
version: str
@ -404,7 +409,7 @@ class ModelManager(object):
if model_key not in self.models:
self.scan_models_directory(base_model=base_model, model_type=model_type)
if model_key not in self.models:
raise Exception(f"Model not found - {model_key}")
raise ModelNotFoundException(f"Model not found - {model_key}")
model_config = self.models[model_key]
model_path = self.app_config.root_path / model_config.path
@ -416,7 +421,7 @@ class ModelManager(object):
else:
self.models.pop(model_key, None)
raise Exception(f"Model not found - {model_key}")
raise ModelNotFoundException(f"Model not found - {model_key}")
# vae/movq override
# TODO:
@ -571,13 +576,16 @@ class ModelManager(object):
model_type: ModelType,
model_attributes: dict,
clobber: bool = False,
) -> None:
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
On a successful update, the config will be changed in memory and the
method will return True. Will fail with an assertion error if provided
attributes are incorrect or the model name is missing.
The returned dict has the same format as the dict returned by
model_info().
"""
model_class = MODEL_CLASSES[base_model][model_type]
@ -601,12 +609,18 @@ class ModelManager(object):
old_model_cache.unlink()
# remove in-memory cache
# note: it not garantie to release memory(model can has other references)
# note: it not guaranteed to release memory(model can has other references)
cache_ids = self.cache_keys.pop(model_key, [])
for cache_id in cache_ids:
self.cache.uncache_model(cache_id)
self.models[model_key] = model_config
return AddModelResult(
name = model_name,
model_type = model_type,
base_model = base_model,
config = model_config,
)
def search_models(self, search_folder):
self.logger.info(f"Finding Models In: {search_folder}")
@ -717,19 +731,19 @@ class ModelManager(object):
if model_path.is_relative_to(self.app_config.root_path):
model_path = model_path.relative_to(self.app_config.root_path)
try:
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
self.models[model_key] = model_config
new_models_found = True
except NotImplementedError as e:
self.logger.warning(e)
try:
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
self.models[model_key] = model_config
new_models_found = True
except NotImplementedError as e:
self.logger.warning(e)
imported_models = self.autoimport()
if (new_models_found or imported_models) and self.config_path:
self.commit()
def autoimport(self)->set[Path]:
def autoimport(self)->Dict[str, AddModelResult]:
'''
Scan the autoimport directory (if defined) and import new models, delete defunct models.
'''
@ -742,7 +756,6 @@ class ModelManager(object):
prediction_type_helper = ask_user_for_prediction_type,
)
installed = set()
scanned_dirs = set()
config = self.app_config
@ -756,13 +769,14 @@ class ModelManager(object):
continue
self.logger.info(f'Scanning {autodir} for models to import')
installed = dict()
autodir = self.app_config.root_path / autodir
if not autodir.exists():
continue
items_scanned = 0
new_models_found = set()
new_models_found = dict()
for root, dirs, files in os.walk(autodir):
items_scanned += len(dirs) + len(files)
@ -772,7 +786,7 @@ class ModelManager(object):
scanned_dirs.add(path)
continue
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
new_models_found.update(installer.heuristic_install(path))
new_models_found.update(installer.heuristic_import(path))
scanned_dirs.add(path)
for f in files:
@ -780,7 +794,7 @@ class ModelManager(object):
if path in known_paths or path.parent in scanned_dirs:
continue
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
new_models_found.update(installer.heuristic_install(path))
new_models_found.update(installer.heuristic_import(path))
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
installed.update(new_models_found)
@ -790,7 +804,7 @@ class ModelManager(object):
def heuristic_import(self,
items_to_import: Set[str],
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
)->Set[str]:
)->Dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
@ -803,17 +817,20 @@ class ModelManager(object):
generally impossible to do this programmatically, so the
prediction_type_helper usually asks the user to choose.
The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.
'''
# avoid circular import here
from invokeai.backend.install.model_install_backend import ModelInstall
successfully_installed = set()
successfully_installed = dict()
installer = ModelInstall(config = self.app_config,
prediction_type_helper = prediction_type_helper,
model_manager = self)
for thing in items_to_import:
try:
installed = installer.heuristic_install(thing)
installed = installer.heuristic_import(thing)
successfully_installed.update(installed)
except Exception as e:
self.logger.warning(f'{thing} could not be imported: {str(e)}')

View File

@ -2,7 +2,7 @@ import inspect
from enum import Enum
from pydantic import BaseModel
from typing import Literal, get_origin
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .vae import VaeModel
from .lora import LoRAModel

View File

@ -15,6 +15,9 @@ from contextlib import suppress
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
class ModelNotFoundException(Exception):
pass
class BaseModelType(str, Enum):
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"

View File

@ -8,6 +8,7 @@ from .base import (
ModelType,
SubModelType,
classproperty,
ModelNotFoundException,
)
# TODO: naming
from ..lora import TextualInversionModel as TextualInversionModelRaw
@ -37,8 +38,15 @@ class TextualInversionModel(ModelBase):
if child_type is not None:
raise Exception("There is no child models in textual inversion")
checkpoint_path = self.model_path
if os.path.isdir(checkpoint_path):
checkpoint_path = os.path.join(checkpoint_path, "learned_embeds.bin")
if not os.path.exists(checkpoint_path):
raise ModelNotFoundException()
model = TextualInversionModelRaw.from_checkpoint(
file_path=self.model_path,
file_path=checkpoint_path,
dtype=torch_dtype,
)

View File

@ -12,7 +12,7 @@
margin: 0;
}
</style>
<script type="module" crossorigin src="./assets/index-8a3e9251.js"></script>
<script type="module" crossorigin src="./assets/index-c0367e37.js"></script>
</head>
<body dir="ltr">

View File

@ -24,16 +24,13 @@
},
"common": {
"hotkeysLabel": "Hotkeys",
"themeLabel": "Theme",
"darkMode": "Dark Mode",
"lightMode": "Light Mode",
"languagePickerLabel": "Language",
"reportBugLabel": "Report Bug",
"githubLabel": "Github",
"discordLabel": "Discord",
"settingsLabel": "Settings",
"darkTheme": "Dark",
"lightTheme": "Light",
"greenTheme": "Green",
"oceanTheme": "Ocean",
"langArabic": "العربية",
"langEnglish": "English",
"langDutch": "Nederlands",
@ -55,6 +52,7 @@
"unifiedCanvas": "Unified Canvas",
"linear": "Linear",
"nodes": "Node Editor",
"modelmanager": "Model Manager",
"postprocessing": "Post Processing",
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
"postProcessing": "Post Processing",
@ -336,6 +334,7 @@
"modelManager": {
"modelManager": "Model Manager",
"model": "Model",
"vae": "VAE",
"allModels": "All Models",
"checkpointModels": "Checkpoints",
"diffusersModels": "Diffusers",
@ -351,6 +350,7 @@
"scanForModels": "Scan For Models",
"addManually": "Add Manually",
"manual": "Manual",
"baseModel": "Base Model",
"name": "Name",
"nameValidationMsg": "Enter a name for your model",
"description": "Description",
@ -363,6 +363,7 @@
"repoIDValidationMsg": "Online repository of your model",
"vaeLocation": "VAE Location",
"vaeLocationValidationMsg": "Path to where your VAE is located.",
"variant": "Variant",
"vaeRepoID": "VAE Repo ID",
"vaeRepoIDValidationMsg": "Online repository of your VAE",
"width": "Width",
@ -524,7 +525,8 @@
"initialImage": "Initial Image",
"showOptionsPanel": "Show Options Panel",
"hidePreview": "Hide Preview",
"showPreview": "Show Preview"
"showPreview": "Show Preview",
"controlNetControlMode": "Control Mode"
},
"settings": {
"models": "Models",
@ -547,7 +549,8 @@
"general": "General",
"generation": "Generation",
"ui": "User Interface",
"availableSchedulers": "Available Schedulers"
"favoriteSchedulers": "Favorite Schedulers",
"favoriteSchedulersPlaceholder": "No schedulers favorited"
},
"toast": {
"serverError": "Server Error",

View File

@ -67,6 +67,7 @@
"@fontsource-variable/inter": "^5.0.3",
"@fontsource/inter": "^5.0.3",
"@mantine/core": "^6.0.14",
"@mantine/form": "^6.0.15",
"@mantine/hooks": "^6.0.14",
"@reduxjs/toolkit": "^1.9.5",
"@roarr/browser-log-writer": "^1.1.5",

View File

@ -53,6 +53,7 @@
"linear": "Linear",
"nodes": "Node Editor",
"batch": "Batch Manager",
"modelmanager": "Model Manager",
"postprocessing": "Post Processing",
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
"postProcessing": "Post Processing",
@ -334,6 +335,7 @@
"modelManager": {
"modelManager": "Model Manager",
"model": "Model",
"vae": "VAE",
"allModels": "All Models",
"checkpointModels": "Checkpoints",
"diffusersModels": "Diffusers",
@ -349,6 +351,7 @@
"scanForModels": "Scan For Models",
"addManually": "Add Manually",
"manual": "Manual",
"baseModel": "Base Model",
"name": "Name",
"nameValidationMsg": "Enter a name for your model",
"description": "Description",
@ -361,6 +364,7 @@
"repoIDValidationMsg": "Online repository of your model",
"vaeLocation": "VAE Location",
"vaeLocationValidationMsg": "Path to where your VAE is located.",
"variant": "Variant",
"vaeRepoID": "VAE Repo ID",
"vaeRepoIDValidationMsg": "Online repository of your VAE",
"width": "Width",

View File

@ -4,6 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { PartialAppConfig } from 'app/types/invokeai';
import ImageUploader from 'common/components/ImageUploader';
import GalleryDrawer from 'features/gallery/components/GalleryPanel';
import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal';
import Lightbox from 'features/lightbox/components/Lightbox';
import SiteHeader from 'features/system/components/SiteHeader';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
@ -15,11 +16,10 @@ import InvokeTabs from 'features/ui/components/InvokeTabs';
import ParametersDrawer from 'features/ui/components/ParametersDrawer';
import i18n from 'i18n';
import { ReactNode, memo, useEffect } from 'react';
import DeleteBoardImagesModal from '../../features/gallery/components/Boards/DeleteBoardImagesModal';
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
import GlobalHotkeys from './GlobalHotkeys';
import Toaster from './Toaster';
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
import DeleteBoardImagesModal from '../../features/gallery/components/Boards/DeleteBoardImagesModal';
import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal';
const DEFAULT_CONFIG = {};

View File

@ -1,4 +1,8 @@
import { Box, ChakraProps, Flex, Heading, Image } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { memo } from 'react';
import { TypesafeDraggableData } from './typesafeDnd';
@ -28,7 +32,24 @@ const STYLES: ChakraProps['sx'] = {
},
};
const selector = createSelector(
stateSelector,
(state) => {
const gallerySelectionCount = state.gallery.selection.length;
const batchSelectionCount = state.batch.selection.length;
return {
gallerySelectionCount,
batchSelectionCount,
};
},
defaultSelectorOptions
);
const DragPreview = (props: OverlayDragImageProps) => {
const { gallerySelectionCount, batchSelectionCount } =
useAppSelector(selector);
if (!props.dragData) {
return;
}
@ -57,7 +78,7 @@ const DragPreview = (props: OverlayDragImageProps) => {
);
}
if (props.dragData.payloadType === 'IMAGE_NAMES') {
if (props.dragData.payloadType === 'BATCH_SELECTION') {
return (
<Flex
sx={{
@ -70,7 +91,26 @@ const DragPreview = (props: OverlayDragImageProps) => {
...STYLES,
}}
>
<Heading>{props.dragData.payload.imageNames.length}</Heading>
<Heading>{batchSelectionCount}</Heading>
<Heading size="sm">Images</Heading>
</Flex>
);
}
if (props.dragData.payloadType === 'GALLERY_SELECTION') {
return (
<Flex
sx={{
cursor: 'none',
userSelect: 'none',
position: 'relative',
alignItems: 'center',
justifyContent: 'center',
flexDir: 'column',
...STYLES,
}}
>
<Heading>{gallerySelectionCount}</Heading>
<Heading size="sm">Images</Heading>
</Flex>
);

View File

@ -77,14 +77,18 @@ export type ImageDraggableData = BaseDragData & {
payload: { imageDTO: ImageDTO };
};
export type ImageNamesDraggableData = BaseDragData & {
payloadType: 'IMAGE_NAMES';
payload: { imageNames: string[] };
export type GallerySelectionDraggableData = BaseDragData & {
payloadType: 'GALLERY_SELECTION';
};
export type BatchSelectionDraggableData = BaseDragData & {
payloadType: 'BATCH_SELECTION';
};
export type TypesafeDraggableData =
| ImageDraggableData
| ImageNamesDraggableData;
| GallerySelectionDraggableData
| BatchSelectionDraggableData;
interface UseDroppableTypesafeArguments
extends Omit<UseDroppableArguments, 'data'> {
@ -155,11 +159,13 @@ export const isValidDrop = (
case 'SET_NODES_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_MULTI_NODES_IMAGE':
return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
return payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION';
case 'ADD_TO_BATCH':
return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
return payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION';
case 'MOVE_BOARD':
return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
return (
payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION' || 'BATCH_SELECTION'
);
default:
return false;
}

View File

@ -20,10 +20,8 @@ const serializationDenylist: {
nodes: nodesPersistDenylist,
postprocessing: postprocessingPersistDenylist,
system: systemPersistDenylist,
// config: configPersistDenyList,
ui: uiPersistDenylist,
controlNet: controlNetDenylist,
// hotkeys: hotkeysPersistDenylist,
};
export const serialize: SerializeFunction = (data, key) => {

View File

@ -1,21 +1,21 @@
import { startAppListening } from '..';
import { imageDeleted } from 'services/api/thunks/image';
import { log } from 'app/logging/useLogger';
import { clamp } from 'lodash-es';
import {
imageSelected,
imageRemoved,
selectImagesIds,
} from 'features/gallery/store/gallerySlice';
import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { api } from 'services/api';
import {
imageRemoved,
imageSelected,
selectFilteredImages,
} from 'features/gallery/store/gallerySlice';
import {
imageDeletionConfirmed,
isModalOpenChanged,
} from 'features/imageDeletion/store/imageDeletionSlice';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { clamp } from 'lodash-es';
import { api } from 'services/api';
import { imageDeleted } from 'services/api/thunks/image';
import { startAppListening } from '..';
const moduleLog = log.child({ namespace: 'image' });
@ -37,7 +37,9 @@ export const addRequestedImageDeletionListener = () => {
state.gallery.selection[state.gallery.selection.length - 1];
if (lastSelectedImage === image_name) {
const ids = selectImagesIds(state);
const filteredImages = selectFilteredImages(state);
const ids = filteredImages.map((i) => i.image_name);
const deletedImageIndex = ids.findIndex(
(result) => result.toString() === image_name

View File

@ -1,24 +1,23 @@
import { createAction } from '@reduxjs/toolkit';
import { startAppListening } from '../';
import { log } from 'app/logging/useLogger';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { log } from 'app/logging/useLogger';
import {
imageAddedToBatch,
imagesAddedToBatch,
} from 'features/batch/store/batchSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import {
fieldValueChanged,
imageCollectionFieldValueChanged,
} from 'features/nodes/store/nodesSlice';
import { boardsApi } from 'services/api/endpoints/boards';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { boardImagesApi } from 'services/api/endpoints/boardImages';
import { startAppListening } from '../';
const moduleLog = log.child({ namespace: 'dnd' });
@ -33,6 +32,7 @@ export const addImageDroppedListener = () => {
effect: (action, { dispatch, getState }) => {
const { activeData, overData } = action.payload;
const { actionType } = overData;
const state = getState();
// set current image
if (
@ -64,9 +64,9 @@ export const addImageDroppedListener = () => {
// add multiple images to batch
if (
actionType === 'ADD_TO_BATCH' &&
activeData.payloadType === 'IMAGE_NAMES'
activeData.payloadType === 'GALLERY_SELECTION'
) {
dispatch(imagesAddedToBatch(activeData.payload.imageNames));
dispatch(imagesAddedToBatch(state.gallery.selection));
}
// set control image
@ -128,14 +128,14 @@ export const addImageDroppedListener = () => {
// set multiple nodes images (multiple images handler)
if (
actionType === 'SET_MULTI_NODES_IMAGE' &&
activeData.payloadType === 'IMAGE_NAMES'
activeData.payloadType === 'GALLERY_SELECTION'
) {
const { fieldName, nodeId } = overData.context;
dispatch(
imageCollectionFieldValueChanged({
nodeId,
fieldName,
value: activeData.payload.imageNames.map((image_name) => ({
value: state.gallery.selection.map((image_name) => ({
image_name,
})),
})

View File

@ -8,31 +8,32 @@ import {
import dynamicMiddlewares from 'redux-dynamic-middlewares';
import { rememberEnhancer, rememberReducer } from 'redux-remember';
import batchReducer from 'features/batch/store/batchSlice';
import canvasReducer from 'features/canvas/store/canvasSlice';
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice';
import boardsReducer from 'features/gallery/store/boardSlice';
import galleryReducer from 'features/gallery/store/gallerySlice';
import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice';
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
import loraReducer from 'features/lora/store/loraSlice';
import nodesReducer from 'features/nodes/store/nodesSlice';
import generationReducer from 'features/parameters/store/generationSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import systemReducer from 'features/system/store/systemSlice';
import nodesReducer from 'features/nodes/store/nodesSlice';
import boardsReducer from 'features/gallery/store/boardSlice';
import configReducer from 'features/system/store/configSlice';
import systemReducer from 'features/system/store/systemSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import uiReducer from 'features/ui/store/uiSlice';
import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice';
import batchReducer from 'features/batch/store/batchSlice';
import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice';
import { listenerMiddleware } from './middleware/listenerMiddleware';
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
import { api } from 'services/api';
import { LOCALSTORAGE_PREFIX } from './constants';
import { serialize } from './enhancers/reduxRemember/serialize';
import { unserialize } from './enhancers/reduxRemember/unserialize';
import { api } from 'services/api';
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
const allReducers = {
canvas: canvasReducer,
@ -50,6 +51,7 @@ const allReducers = {
dynamicPrompts: dynamicPromptsReducer,
batch: batchReducer,
imageDeletion: imageDeletionReducer,
lora: loraReducer,
[api.reducerPath]: api.reducer,
};
@ -69,6 +71,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'controlNet',
'dynamicPrompts',
'batch',
'lora',
// 'boards',
// 'hotkeys',
// 'config',

View File

@ -4,22 +4,25 @@ import {
Collapse,
Flex,
Spacer,
Switch,
Text,
useColorMode,
useDisclosure,
} from '@chakra-ui/react';
import { AnimatePresence, motion } from 'framer-motion';
import { PropsWithChildren, memo } from 'react';
import { mode } from 'theme/util/mode';
export type IAIToggleCollapseProps = PropsWithChildren & {
label: string;
isOpen: boolean;
onToggle: () => void;
withSwitch?: boolean;
activeLabel?: string;
defaultIsOpen?: boolean;
};
const IAICollapse = (props: IAIToggleCollapseProps) => {
const { label, isOpen, onToggle, children, withSwitch = false } = props;
const { label, activeLabel, children, defaultIsOpen = false } = props;
const { isOpen, onToggle } = useDisclosure({ defaultIsOpen });
const { colorMode } = useColorMode();
return (
<Box>
<Flex
@ -28,6 +31,7 @@ const IAICollapse = (props: IAIToggleCollapseProps) => {
alignItems: 'center',
p: 2,
px: 4,
gap: 2,
borderTopRadius: 'base',
borderBottomRadius: isOpen ? 0 : 'base',
bg: isOpen
@ -48,19 +52,40 @@ const IAICollapse = (props: IAIToggleCollapseProps) => {
}}
>
{label}
<AnimatePresence>
{activeLabel && (
<motion.div
key="statusText"
initial={{
opacity: 0,
}}
animate={{
opacity: 1,
transition: { duration: 0.1 },
}}
exit={{
opacity: 0,
transition: { duration: 0.1 },
}}
>
<Text
sx={{ color: 'accent.500', _dark: { color: 'accent.300' } }}
>
{activeLabel}
</Text>
</motion.div>
)}
</AnimatePresence>
<Spacer />
{withSwitch && <Switch isChecked={isOpen} pointerEvents="none" />}
{!withSwitch && (
<ChevronUpIcon
sx={{
w: '1rem',
h: '1rem',
transform: isOpen ? 'rotate(0deg)' : 'rotate(180deg)',
transitionProperty: 'common',
transitionDuration: 'normal',
}}
/>
)}
<ChevronUpIcon
sx={{
w: '1rem',
h: '1rem',
transform: isOpen ? 'rotate(0deg)' : 'rotate(180deg)',
transitionProperty: 'common',
transitionDuration: 'normal',
}}
/>
</Flex>
<Collapse in={isOpen} animateOpacity style={{ overflow: 'unset' }}>
<Box

View File

@ -61,7 +61,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
'&:focus-within': {
borderColor: mode(accent200, accent600)(colorMode),
},
'&:disabled': {
'&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode),
},

View File

@ -64,7 +64,7 @@ const IAIMantineSelect = (props: IAISelectProps) => {
'&:focus-within': {
borderColor: mode(accent200, accent600)(colorMode),
},
'&:disabled': {
'&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode),
},

View File

@ -36,7 +36,6 @@ const IAISwitch = (props: Props) => {
isDisabled={isDisabled}
width={width}
display="flex"
gap={4}
alignItems="center"
{...formControlProps}
>
@ -47,6 +46,7 @@ const IAISwitch = (props: Props) => {
sx={{
cursor: isDisabled ? 'not-allowed' : 'pointer',
...formLabelProps?.sx,
pe: 4,
}}
{...formLabelProps}
>

View File

@ -1,28 +1,29 @@
import { Box, Icon, Skeleton } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { FaExclamationCircle } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { MouseEvent, memo, useCallback, useMemo } from 'react';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage';
import {
batchImageRangeEndSelected,
batchImageSelected,
batchImageSelectionToggled,
imageRemovedFromBatch,
} from 'features/batch/store/batchSlice';
import IAIDndImage from 'common/components/IAIDndImage';
import { createSelector } from '@reduxjs/toolkit';
import { RootState, stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd';
import { MouseEvent, memo, useCallback, useMemo } from 'react';
import { FaExclamationCircle } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
const isSelectedSelector = createSelector(
[stateSelector, (state: RootState, imageName: string) => imageName],
(state, imageName) => ({
selection: state.batch.selection,
isSelected: state.batch.selection.includes(imageName),
}),
defaultSelectorOptions
);
const makeSelector = (image_name: string) =>
createSelector(
[stateSelector],
(state) => ({
selectionCount: state.batch.selection.length,
isSelected: state.batch.selection.includes(image_name),
}),
defaultSelectorOptions
);
type BatchImageProps = {
imageName: string;
@ -37,10 +38,13 @@ const BatchImage = (props: BatchImageProps) => {
} = useGetImageDTOQuery(props.imageName);
const dispatch = useAppDispatch();
const { isSelected, selection } = useAppSelector((state) =>
isSelectedSelector(state, props.imageName)
const selector = useMemo(
() => makeSelector(props.imageName),
[props.imageName]
);
const { isSelected, selectionCount } = useAppSelector(selector);
const handleClickRemove = useCallback(() => {
dispatch(imageRemovedFromBatch(props.imageName));
}, [dispatch, props.imageName]);
@ -59,13 +63,10 @@ const BatchImage = (props: BatchImageProps) => {
);
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
if (selection.length > 1) {
if (selectionCount > 1) {
return {
id: 'batch',
payloadType: 'IMAGE_NAMES',
payload: {
imageNames: selection,
},
payloadType: 'BATCH_SELECTION',
};
}
@ -76,7 +77,7 @@ const BatchImage = (props: BatchImageProps) => {
payload: { imageDTO },
};
}
}, [imageDTO, selection]);
}, [imageDTO, selectionCount]);
if (isError) {
return <Icon as={FaExclamationCircle} />;

View File

@ -1,25 +1,22 @@
import { memo, useCallback, useMemo, useState } from 'react';
import { ImageDTO } from 'services/api/types';
import {
ControlNetConfig,
controlNetImageChanged,
controlNetSelector,
} from '../store/controlNetSlice';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { Box, Flex, SystemStyleObject } from '@chakra-ui/react';
import IAIDndImage from 'common/components/IAIDndImage';
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { IAILoadingImageFallback } from 'common/components/IAIImageFallback';
import IAIIconButton from 'common/components/IAIIconButton';
import { FaUndo } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage';
import { IAILoadingImageFallback } from 'common/components/IAIImageFallback';
import { memo, useCallback, useMemo, useState } from 'react';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/thunks/image';
import {
ControlNetConfig,
controlNetImageChanged,
controlNetSelector,
} from '../store/controlNetSlice';
const selector = createSelector(
controlNetSelector,
@ -83,15 +80,14 @@ const ControlNetImagePreview = (props: Props) => {
}
}, [controlImage, controlNetId]);
const droppableData = useMemo<TypesafeDroppableData | undefined>(() => {
if (controlNetId) {
return {
id: controlNetId,
actionType: 'SET_CONTROLNET_IMAGE',
context: { controlNetId },
};
}
}, [controlNetId]);
const droppableData = useMemo<TypesafeDroppableData | undefined>(
() => ({
id: controlNetId,
actionType: 'SET_CONTROLNET_IMAGE',
context: { controlNetId },
}),
[controlNetId]
);
const postUploadAction = useMemo<PostUploadAction>(
() => ({ type: 'SET_CONTROLNET_IMAGE', controlNetId }),

View File

@ -0,0 +1,36 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISwitch from 'common/components/IAISwitch';
import { isControlNetEnabledToggled } from 'features/controlNet/store/controlNetSlice';
import { useCallback } from 'react';
const selector = createSelector(
stateSelector,
(state) => {
const { isEnabled } = state.controlNet;
return { isEnabled };
},
defaultSelectorOptions
);
const ParamControlNetFeatureToggle = () => {
const { isEnabled } = useAppSelector(selector);
const dispatch = useAppDispatch();
const handleChange = useCallback(() => {
dispatch(isControlNetEnabledToggled());
}, [dispatch]);
return (
<IAISwitch
label="Enable ControlNet"
isChecked={isEnabled}
onChange={handleChange}
/>
);
};
export default ParamControlNetFeatureToggle;

View File

@ -0,0 +1,15 @@
import { filter } from 'lodash-es';
import { ControlNetConfig } from '../store/controlNetSlice';
export const getValidControlNets = (
controlNets: Record<string, ControlNetConfig>
) => {
const validControlNets = filter(
controlNets,
(c) =>
c.isEnabled &&
(Boolean(c.processedControlImage) ||
(c.processorType === 'none' && Boolean(c.controlImage)))
);
return validControlNets;
};

View File

@ -1,40 +1,30 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import { useCallback } from 'react';
import { isEnabledToggled } from '../store/slice';
import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts';
import ParamDynamicPromptsCombinatorial from './ParamDynamicPromptsCombinatorial';
import { Flex } from '@chakra-ui/react';
import ParamDynamicPromptsToggle from './ParamDynamicPromptsEnabled';
import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts';
const selector = createSelector(
stateSelector,
(state) => {
const { isEnabled } = state.dynamicPrompts;
return { isEnabled };
return { activeLabel: isEnabled ? 'Enabled' : undefined };
},
defaultSelectorOptions
);
const ParamDynamicPromptsCollapse = () => {
const dispatch = useAppDispatch();
const { isEnabled } = useAppSelector(selector);
const handleToggleIsEnabled = useCallback(() => {
dispatch(isEnabledToggled());
}, [dispatch]);
const { activeLabel } = useAppSelector(selector);
return (
<IAICollapse
isOpen={isEnabled}
onToggle={handleToggleIsEnabled}
label="Dynamic Prompts"
withSwitch
>
<IAICollapse label="Dynamic Prompts" activeLabel={activeLabel}>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<ParamDynamicPromptsToggle />
<ParamDynamicPromptsCombinatorial />
<ParamDynamicPromptsMaxPrompts />
</Flex>

View File

@ -1,23 +1,23 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { combinatorialToggled } from '../store/slice';
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useCallback } from 'react';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISwitch from 'common/components/IAISwitch';
import { useCallback } from 'react';
import { combinatorialToggled } from '../store/slice';
const selector = createSelector(
stateSelector,
(state) => {
const { combinatorial } = state.dynamicPrompts;
const { combinatorial, isEnabled } = state.dynamicPrompts;
return { combinatorial };
return { combinatorial, isDisabled: !isEnabled };
},
defaultSelectorOptions
);
const ParamDynamicPromptsCombinatorial = () => {
const { combinatorial } = useAppSelector(selector);
const { combinatorial, isDisabled } = useAppSelector(selector);
const dispatch = useAppDispatch();
const handleChange = useCallback(() => {
@ -26,6 +26,7 @@ const ParamDynamicPromptsCombinatorial = () => {
return (
<IAISwitch
isDisabled={isDisabled}
label="Combinatorial Generation"
isChecked={combinatorial}
onChange={handleChange}

View File

@ -0,0 +1,36 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISwitch from 'common/components/IAISwitch';
import { useCallback } from 'react';
import { isEnabledToggled } from '../store/slice';
const selector = createSelector(
stateSelector,
(state) => {
const { isEnabled } = state.dynamicPrompts;
return { isEnabled };
},
defaultSelectorOptions
);
const ParamDynamicPromptsToggle = () => {
const dispatch = useAppDispatch();
const { isEnabled } = useAppSelector(selector);
const handleToggleIsEnabled = useCallback(() => {
dispatch(isEnabledToggled());
}, [dispatch]);
return (
<IAISwitch
label="Enable Dynamic Prompts"
isChecked={isEnabled}
onChange={handleToggleIsEnabled}
/>
);
};
export default ParamDynamicPromptsToggle;

View File

@ -1,25 +1,31 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { maxPromptsChanged, maxPromptsReset } from '../store/slice';
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useCallback } from 'react';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider';
import { useCallback } from 'react';
import { maxPromptsChanged, maxPromptsReset } from '../store/slice';
const selector = createSelector(
stateSelector,
(state) => {
const { maxPrompts, combinatorial } = state.dynamicPrompts;
const { maxPrompts, combinatorial, isEnabled } = state.dynamicPrompts;
const { min, sliderMax, inputMax } =
state.config.sd.dynamicPrompts.maxPrompts;
return { maxPrompts, min, sliderMax, inputMax, combinatorial };
return {
maxPrompts,
min,
sliderMax,
inputMax,
isDisabled: !isEnabled || !combinatorial,
};
},
defaultSelectorOptions
);
const ParamDynamicPromptsMaxPrompts = () => {
const { maxPrompts, min, sliderMax, inputMax, combinatorial } =
const { maxPrompts, min, sliderMax, inputMax, isDisabled } =
useAppSelector(selector);
const dispatch = useAppDispatch();
@ -37,7 +43,7 @@ const ParamDynamicPromptsMaxPrompts = () => {
return (
<IAISlider
label="Max Prompts"
isDisabled={!combinatorial}
isDisabled={isDisabled}
min={min}
max={sliderMax}
value={maxPrompts}

View File

@ -1,19 +1,19 @@
import { Box, Flex, Image } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { isEqual } from 'lodash-es';
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
import NextPrevImageButtons from './NextPrevImageButtons';
import { memo, useMemo } from 'react';
import IAIDndImage from 'common/components/IAIDndImage';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { stateSelector } from 'app/store/store';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySlice';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySlice';
import { isEqual } from 'lodash-es';
import { memo, useMemo } from 'react';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
import NextPrevImageButtons from './NextPrevImageButtons';
export const imagesSelector = createSelector(
[stateSelector, selectLastSelectedImage],

View File

@ -1,34 +1,35 @@
import { Box } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { MouseEvent, memo, useCallback, useMemo } from 'react';
import { FaTrash } from 'react-icons/fa';
import { useTranslation } from 'react-i18next';
import { createSelector } from '@reduxjs/toolkit';
import { ImageDTO } from 'services/api/types';
import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import ImageContextMenu from './ImageContextMenu';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage';
import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice';
import { MouseEvent, memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { FaTrash } from 'react-icons/fa';
import { ImageDTO } from 'services/api/types';
import {
imageRangeEndSelected,
imageSelected,
imageSelectionToggled,
} from '../store/gallerySlice';
import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice';
import ImageContextMenu from './ImageContextMenu';
export const selector = createSelector(
[stateSelector, (state, { image_name }: ImageDTO) => image_name],
({ gallery }, image_name) => {
const isSelected = gallery.selection.includes(image_name);
const selection = gallery.selection;
return {
isSelected,
selection,
};
},
defaultSelectorOptions
);
export const makeSelector = (image_name: string) =>
createSelector(
[stateSelector],
({ gallery }) => {
const isSelected = gallery.selection.includes(image_name);
const selectionCount = gallery.selection.length;
return {
isSelected,
selectionCount,
};
},
defaultSelectorOptions
);
interface HoverableImageProps {
imageDTO: ImageDTO;
@ -38,13 +39,13 @@ interface HoverableImageProps {
* Gallery image component with delete/use all/use seed buttons on hover.
*/
const GalleryImage = (props: HoverableImageProps) => {
const { isSelected, selection } = useAppSelector((state) =>
selector(state, props.imageDTO)
);
const { imageDTO } = props;
const { image_url, thumbnail_url, image_name } = imageDTO;
const localSelector = useMemo(() => makeSelector(image_name), [image_name]);
const { isSelected, selectionCount } = useAppSelector(localSelector);
const dispatch = useAppDispatch();
const { t } = useTranslation();
@ -74,11 +75,10 @@ const GalleryImage = (props: HoverableImageProps) => {
);
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
if (selection.length > 1) {
if (selectionCount > 1) {
return {
id: 'gallery-image',
payloadType: 'IMAGE_NAMES',
payload: { imageNames: selection },
payloadType: 'GALLERY_SELECTION',
};
}
@ -89,7 +89,7 @@ const GalleryImage = (props: HoverableImageProps) => {
payload: { imageDTO },
};
}
}, [imageDTO, selection]);
}, [imageDTO, selectionCount]);
return (
<Box sx={{ w: 'full', h: 'full', touchAction: 'none' }}>

View File

@ -7,7 +7,6 @@ import {
import { RootState } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { dateComparator } from 'common/util/dateComparator';
import { imageDeletionConfirmed } from 'features/imageDeletion/store/imageDeletionSlice';
import { keyBy, uniq } from 'lodash-es';
import { boardsApi } from 'services/api/endpoints/boards';
import {
@ -174,11 +173,6 @@ export const gallerySlice = createSlice({
state.limit = limit;
state.total = total;
});
builder.addCase(imageDeletionConfirmed, (state, action) => {
// Image deleted
const { image_name } = action.payload.imageDTO;
imagesAdapter.removeOne(state, image_name);
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_url, thumbnail_url } = action.payload;

View File

@ -23,6 +23,7 @@ import { stateSelector } from 'app/store/store';
import {
imageDeletionConfirmed,
imageToDeleteCleared,
isModalOpenChanged,
selectImageUsage,
} from '../store/imageDeletionSlice';
@ -63,6 +64,7 @@ const DeleteImageModal = () => {
const handleClose = useCallback(() => {
dispatch(imageToDeleteCleared());
dispatch(isModalOpenChanged(false));
}, [dispatch]);
const handleDelete = useCallback(() => {

View File

@ -31,6 +31,7 @@ const imageDeletion = createSlice({
},
imageToDeleteCleared: (state) => {
state.imageToDelete = null;
state.isModalOpen = false;
},
},
});

View File

@ -0,0 +1,59 @@
import { Flex } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import IAISlider from 'common/components/IAISlider';
import { memo, useCallback } from 'react';
import { FaTrash } from 'react-icons/fa';
import { Lora, loraRemoved, loraWeightChanged } from '../store/loraSlice';
type Props = {
lora: Lora;
};
const ParamLora = (props: Props) => {
const dispatch = useAppDispatch();
const { lora } = props;
const handleChange = useCallback(
(v: number) => {
dispatch(loraWeightChanged({ id: lora.id, weight: v }));
},
[dispatch, lora.id]
);
const handleReset = useCallback(() => {
dispatch(loraWeightChanged({ id: lora.id, weight: 1 }));
}, [dispatch, lora.id]);
const handleRemoveLora = useCallback(() => {
dispatch(loraRemoved(lora.id));
}, [dispatch, lora.id]);
return (
<Flex sx={{ gap: 2.5, alignItems: 'flex-end' }}>
<IAISlider
label={lora.name}
value={lora.weight}
onChange={handleChange}
min={-1}
max={2}
step={0.01}
withInput
withReset
handleReset={handleReset}
withSliderMarks
sliderMarks={[-1, 0, 1, 2]}
/>
<IAIIconButton
size="sm"
onClick={handleRemoveLora}
tooltip="Remove LoRA"
aria-label="Remove LoRA"
icon={<FaTrash />}
colorScheme="error"
/>
</Flex>
);
};
export default memo(ParamLora);

View File

@ -0,0 +1,36 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import { size } from 'lodash-es';
import { memo } from 'react';
import ParamLoraList from './ParamLoraList';
import ParamLoraSelect from './ParamLoraSelect';
const selector = createSelector(
stateSelector,
(state) => {
const loraCount = size(state.lora.loras);
return {
activeLabel: loraCount > 0 ? `${loraCount} Active` : undefined,
};
},
defaultSelectorOptions
);
const ParamLoraCollapse = () => {
const { activeLabel } = useAppSelector(selector);
return (
<IAICollapse label={'LoRA'} activeLabel={activeLabel}>
<Flex sx={{ flexDir: 'column', gap: 2 }}>
<ParamLoraSelect />
<ParamLoraList />
</Flex>
</IAICollapse>
);
};
export default memo(ParamLoraCollapse);

View File

@ -0,0 +1,24 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { map } from 'lodash-es';
import ParamLora from './ParamLora';
const selector = createSelector(
stateSelector,
({ lora }) => {
const { loras } = lora;
return { loras };
},
defaultSelectorOptions
);
const ParamLoraList = () => {
const { loras } = useAppSelector(selector);
return map(loras, (lora) => <ParamLora key={lora.name} lora={lora} />);
};
export default ParamLoraList;

View File

@ -0,0 +1,107 @@
import { Text } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
import { forEach } from 'lodash-es';
import { forwardRef, useCallback, useMemo } from 'react';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
import { loraAdded } from '../store/loraSlice';
type LoraSelectItem = {
label: string;
value: string;
description?: string;
};
const selector = createSelector(
stateSelector,
({ lora }) => ({
loras: lora.loras,
}),
defaultSelectorOptions
);
const ParamLoraSelect = () => {
const dispatch = useAppDispatch();
const { loras } = useAppSelector(selector);
const { data: lorasQueryData } = useGetLoRAModelsQuery();
const data = useMemo(() => {
if (!lorasQueryData) {
return [];
}
const data: LoraSelectItem[] = [];
forEach(lorasQueryData.entities, (lora, id) => {
if (!lora || Boolean(id in loras)) {
return;
}
data.push({
value: id,
label: lora.name,
description: lora.description,
});
});
return data;
}, [loras, lorasQueryData]);
const handleChange = useCallback(
(v: string[]) => {
const loraEntity = lorasQueryData?.entities[v[0]];
if (!loraEntity) {
return;
}
v[0] && dispatch(loraAdded(loraEntity));
},
[dispatch, lorasQueryData?.entities]
);
return (
<IAIMantineMultiSelect
placeholder={data.length === 0 ? 'All LoRAs added' : 'Add LoRA'}
value={[]}
data={data}
maxDropdownHeight={400}
nothingFound="No matching LoRAs"
itemComponent={SelectItem}
disabled={data.length === 0}
filter={(value, selected, item: LoraSelectItem) =>
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim())
}
onChange={handleChange}
/>
);
};
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
value: string;
label: string;
description?: string;
}
const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
({ label, description, ...others }: ItemProps, ref) => {
return (
<div ref={ref} {...others}>
<div>
<Text>{label}</Text>
{description && (
<Text size="xs" color="base.600">
{description}
</Text>
)}
</div>
</div>
);
}
);
SelectItem.displayName = 'SelectItem';
export default ParamLoraSelect;

View File

@ -0,0 +1,46 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { LoRAModelConfigEntity } from 'services/api/endpoints/models';
export type Lora = {
id: string;
name: string;
weight: number;
};
export const defaultLoRAConfig: Omit<Lora, 'id' | 'name'> = {
weight: 1,
};
export type LoraState = {
loras: Record<string, Lora>;
};
export const intialLoraState: LoraState = {
loras: {},
};
export const loraSlice = createSlice({
name: 'lora',
initialState: intialLoraState,
reducers: {
loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => {
const { name, id } = action.payload;
state.loras[id] = { id, name, ...defaultLoRAConfig };
},
loraRemoved: (state, action: PayloadAction<string>) => {
const id = action.payload;
delete state.loras[id];
},
loraWeightChanged: (
state,
action: PayloadAction<{ id: string; weight: number }>
) => {
const { id, weight } = action.payload;
state.loras[id].weight = weight;
},
},
});
export const { loraAdded, loraRemoved, loraWeightChanged } = loraSlice.actions;
export default loraSlice.reducer;

View File

@ -3,20 +3,22 @@ import { memo } from 'react';
import { InputFieldTemplate, InputFieldValue } from '../types/types';
import ArrayInputFieldComponent from './fields/ArrayInputFieldComponent';
import BooleanInputFieldComponent from './fields/BooleanInputFieldComponent';
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
import UNetInputFieldComponent from './fields/UNetInputFieldComponent';
import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
import LoRAModelInputFieldComponent from './fields/LoRAModelInputFieldComponent';
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
import StringInputFieldComponent from './fields/StringInputFieldComponent';
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
import UNetInputFieldComponent from './fields/UNetInputFieldComponent';
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent';
type InputFieldComponentProps = {
nodeId: string;
@ -152,6 +154,26 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
);
}
if (type === 'vae_model' && template.type === 'vae_model') {
return (
<VaeModelInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'lora_model' && template.type === 'lora_model') {
return (
<LoRAModelInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'array' && template.type === 'array') {
return (
<ArrayInputFieldComponent

View File

@ -7,18 +7,16 @@ import {
} from 'features/nodes/types/types';
import { memo, useCallback, useMemo } from 'react';
import { FieldComponentProps } from './types';
import IAIDndImage from 'common/components/IAIDndImage';
import { ImageDTO } from 'services/api/types';
import { Flex } from '@chakra-ui/react';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import {
NodesImageDropData,
TypesafeDraggableData,
TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd';
import IAIDndImage from 'common/components/IAIDndImage';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/thunks/image';
import { FieldComponentProps } from './types';
const ImageInputFieldComponent = (
props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate>
@ -34,23 +32,6 @@ const ImageInputFieldComponent = (
isSuccess,
} = useGetImageDTOQuery(field.value?.image_name ?? skipToken);
const handleDrop = useCallback(
({ image_name }: ImageDTO) => {
if (field.value?.image_name === image_name) {
return;
}
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: { image_name },
})
);
},
[dispatch, field.name, field.value, nodeId]
);
const handleReset = useCallback(() => {
dispatch(
fieldValueChanged({
@ -71,15 +52,14 @@ const ImageInputFieldComponent = (
}
}, [field.name, imageDTO, nodeId]);
const droppableData = useMemo<TypesafeDroppableData | undefined>(() => {
if (imageDTO) {
return {
id: `node-${nodeId}-${field.name}`,
actionType: 'SET_NODES_IMAGE',
context: { nodeId, fieldName: field.name },
};
}
}, [field.name, imageDTO, nodeId]);
const droppableData = useMemo<TypesafeDroppableData | undefined>(
() => ({
id: `node-${nodeId}-${field.name}`,
actionType: 'SET_NODES_IMAGE',
context: { nodeId, fieldName: field.name },
}),
[field.name, nodeId]
);
const postUploadAction = useMemo<PostUploadAction>(
() => ({

View File

@ -0,0 +1,102 @@
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
VaeModelInputFieldTemplate,
VaeModelInputFieldValue,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
import { forEach, isString } from 'lodash-es';
import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
const LoRAModelInputFieldComponent = (
props: FieldComponentProps<
VaeModelInputFieldValue,
VaeModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { data: loraModels } = useGetLoRAModelsQuery();
const selectedModel = useMemo(
() => loraModels?.entities[field.value ?? loraModels.ids[0]],
[loraModels?.entities, loraModels?.ids, field.value]
);
const data = useMemo(() => {
if (!loraModels) {
return [];
}
const data: SelectItem[] = [];
forEach(loraModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.name,
group: BASE_MODEL_NAME_MAP[model.base_model],
});
});
return data;
}, [loraModels]);
const handleValueChanged = useCallback(
(v: string | null) => {
if (!v) {
return;
}
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: v,
})
);
},
[dispatch, field.name, nodeId]
);
useEffect(() => {
if (field.value && loraModels?.ids.includes(field.value)) {
return;
}
const firstLora = loraModels?.ids[0];
if (!isString(firstLora)) {
return;
}
handleValueChanged(firstLora);
}, [field.value, handleValueChanged, loraModels?.ids]);
return (
<IAIMantineSelect
tooltip={selectedModel?.description}
label={
selectedModel?.base_model &&
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
}
value={field.value}
placeholder="Pick one"
data={data}
onChange={handleValueChanged}
/>
);
};
export default memo(LoRAModelInputFieldComponent);

View File

@ -6,13 +6,13 @@ import {
ModelInputFieldValue,
} from 'features/nodes/types/types';
import { memo, useCallback, useEffect, useMemo } from 'react';
import { FieldComponentProps } from './types';
import { forEach, isString } from 'lodash-es';
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
import { forEach, isString } from 'lodash-es';
import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useListModelsQuery } from 'services/api/endpoints/models';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
const ModelInputFieldComponent = (
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
@ -22,18 +22,16 @@ const ModelInputFieldComponent = (
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { data: pipelineModels } = useListModelsQuery({
model_type: 'main',
});
const { data: mainModels } = useGetMainModelsQuery();
const data = useMemo(() => {
if (!pipelineModels) {
if (!mainModels) {
return [];
}
const data: SelectItem[] = [];
forEach(pipelineModels.entities, (model, id) => {
forEach(mainModels.entities, (model, id) => {
if (!model) {
return;
}
@ -46,11 +44,11 @@ const ModelInputFieldComponent = (
});
return data;
}, [pipelineModels]);
}, [mainModels]);
const selectedModel = useMemo(
() => pipelineModels?.entities[field.value ?? pipelineModels.ids[0]],
[pipelineModels?.entities, pipelineModels?.ids, field.value]
() => mainModels?.entities[field.value ?? mainModels.ids[0]],
[mainModels?.entities, mainModels?.ids, field.value]
);
const handleValueChanged = useCallback(
@ -71,18 +69,18 @@ const ModelInputFieldComponent = (
);
useEffect(() => {
if (field.value && pipelineModels?.ids.includes(field.value)) {
if (field.value && mainModels?.ids.includes(field.value)) {
return;
}
const firstModel = pipelineModels?.ids[0];
const firstModel = mainModels?.ids[0];
if (!isString(firstModel)) {
return;
}
handleValueChanged(firstModel);
}, [field.value, handleValueChanged, pipelineModels?.ids]);
}, [field.value, handleValueChanged, mainModels?.ids]);
return (
<IAIMantineSelect

View File

@ -0,0 +1,95 @@
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
VaeModelInputFieldTemplate,
VaeModelInputFieldValue,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
import { forEach } from 'lodash-es';
import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
const VaeModelInputFieldComponent = (
props: FieldComponentProps<
VaeModelInputFieldValue,
VaeModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { data: vaeModels } = useGetVaeModelsQuery();
const selectedModel = useMemo(
() => vaeModels?.entities[field.value ?? vaeModels.ids[0]],
[vaeModels?.entities, vaeModels?.ids, field.value]
);
const data = useMemo(() => {
if (!vaeModels) {
return [];
}
const data: SelectItem[] = [];
forEach(vaeModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.name,
group: BASE_MODEL_NAME_MAP[model.base_model],
});
});
return data;
}, [vaeModels]);
const handleValueChanged = useCallback(
(v: string | null) => {
if (!v) {
return;
}
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: v,
})
);
},
[dispatch, field.name, nodeId]
);
useEffect(() => {
if (field.value && vaeModels?.ids.includes(field.value)) {
return;
}
handleValueChanged('auto');
}, [field.value, handleValueChanged, vaeModels?.ids]);
return (
<IAIMantineSelect
tooltip={selectedModel?.description}
label={
selectedModel?.base_model &&
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
}
value={field.value}
placeholder="Pick one"
data={data}
onChange={handleValueChanged}
/>
);
};
export default memo(VaeModelInputFieldComponent);

View File

@ -1,5 +1,8 @@
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { cloneDeep, uniqBy } from 'lodash-es';
import { OpenAPIV3 } from 'openapi-types';
import { RgbaColor } from 'react-colorful';
import {
addEdge,
applyEdgeChanges,
@ -11,12 +14,9 @@ import {
NodeChange,
OnConnectStartParams,
} from 'reactflow';
import { ImageField } from 'services/api/types';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { ImageField } from 'services/api/types';
import { InvocationTemplate, InvocationValue } from '../types/types';
import { RgbaColor } from 'react-colorful';
import { RootState } from 'app/store/store';
import { cloneDeep, isArray, uniq, uniqBy } from 'lodash-es';
export type NodesState = {
nodes: Node<InvocationValue>[];

View File

@ -17,6 +17,8 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
ClipField: 'clip',
VaeField: 'vae',
model: 'model',
vae_model: 'vae_model',
lora_model: 'lora_model',
array: 'array',
item: 'item',
ColorField: 'color',
@ -116,6 +118,18 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
title: 'Model',
description: 'Models are models.',
},
vae_model: {
color: 'teal',
colorCssVar: getColorTokenCssVariable('teal'),
title: 'VAE',
description: 'Models are models.',
},
lora_model: {
color: 'teal',
colorCssVar: getColorTokenCssVariable('teal'),
title: 'LoRA',
description: 'Models are models.',
},
array: {
color: 'gray',
colorCssVar: getColorTokenCssVariable('gray'),

View File

@ -64,6 +64,8 @@ export type FieldType =
| 'vae'
| 'control'
| 'model'
| 'vae_model'
| 'lora_model'
| 'array'
| 'item'
| 'color'
@ -91,6 +93,8 @@ export type InputFieldValue =
| ControlInputFieldValue
| EnumInputFieldValue
| ModelInputFieldValue
| VaeModelInputFieldValue
| LoRAModelInputFieldValue
| ArrayInputFieldValue
| ItemInputFieldValue
| ColorInputFieldValue
@ -116,6 +120,8 @@ export type InputFieldTemplate =
| ControlInputFieldTemplate
| EnumInputFieldTemplate
| ModelInputFieldTemplate
| VaeModelInputFieldTemplate
| LoRAModelInputFieldTemplate
| ArrayInputFieldTemplate
| ItemInputFieldTemplate
| ColorInputFieldTemplate
@ -228,6 +234,16 @@ export type ModelInputFieldValue = FieldValueBase & {
value?: string;
};
export type VaeModelInputFieldValue = FieldValueBase & {
type: 'vae_model';
value?: string;
};
export type LoRAModelInputFieldValue = FieldValueBase & {
type: 'lora_model';
value?: string;
};
export type ArrayInputFieldValue = FieldValueBase & {
type: 'array';
value?: (string | number)[];
@ -305,6 +321,21 @@ export type ConditioningInputFieldTemplate = InputFieldTemplateBase & {
type: 'conditioning';
};
export type UNetInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'unet';
};
export type ClipInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'clip';
};
export type VaeInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'vae';
};
export type ControlInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'control';
@ -322,6 +353,16 @@ export type ModelInputFieldTemplate = InputFieldTemplateBase & {
type: 'model';
};
export type VaeModelInputFieldTemplate = InputFieldTemplateBase & {
default: string;
type: 'vae_model';
};
export type LoRAModelInputFieldTemplate = InputFieldTemplateBase & {
default: string;
type: 'lora_model';
};
export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
default: [];
type: 'array';

View File

@ -1,5 +1,5 @@
import { RootState } from 'app/store/store';
import { filter } from 'lodash-es';
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
import { CollectInvocation, ControlNetInvocation } from 'services/api/types';
import { NonNullableGraph } from '../types/types';
import { CONTROL_NET_COLLECT } from './graphBuilders/constants';
@ -11,13 +11,7 @@ export const addControlNetToLinearGraph = (
): void => {
const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet;
const validControlNets = filter(
controlNets,
(c) =>
c.isEnabled &&
(Boolean(c.processedControlImage) ||
(c.processorType === 'none' && Boolean(c.controlImage)))
);
const validControlNets = getValidControlNets(controlNets);
if (isControlNetEnabled && Boolean(validControlNets.length)) {
if (validControlNets.length > 1) {

View File

@ -3,27 +3,29 @@ import { OpenAPIV3 } from 'openapi-types';
import { FIELD_TYPE_MAP } from '../types/constants';
import { isSchemaObject } from '../types/typeGuards';
import {
BooleanInputFieldTemplate,
EnumInputFieldTemplate,
FloatInputFieldTemplate,
ImageInputFieldTemplate,
IntegerInputFieldTemplate,
LatentsInputFieldTemplate,
ConditioningInputFieldTemplate,
UNetInputFieldTemplate,
ClipInputFieldTemplate,
VaeInputFieldTemplate,
ControlInputFieldTemplate,
StringInputFieldTemplate,
ModelInputFieldTemplate,
ArrayInputFieldTemplate,
ItemInputFieldTemplate,
BooleanInputFieldTemplate,
ClipInputFieldTemplate,
ColorInputFieldTemplate,
InputFieldTemplateBase,
OutputFieldTemplate,
TypeHints,
ConditioningInputFieldTemplate,
ControlInputFieldTemplate,
EnumInputFieldTemplate,
FieldType,
FloatInputFieldTemplate,
ImageCollectionInputFieldTemplate,
ImageInputFieldTemplate,
InputFieldTemplateBase,
IntegerInputFieldTemplate,
ItemInputFieldTemplate,
LatentsInputFieldTemplate,
LoRAModelInputFieldTemplate,
ModelInputFieldTemplate,
OutputFieldTemplate,
StringInputFieldTemplate,
TypeHints,
UNetInputFieldTemplate,
VaeInputFieldTemplate,
VaeModelInputFieldTemplate,
} from '../types/types';
export type BaseFieldProperties = 'name' | 'title' | 'description';
@ -175,6 +177,36 @@ const buildModelInputFieldTemplate = ({
return template;
};
const buildVaeModelInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): VaeModelInputFieldTemplate => {
const template: VaeModelInputFieldTemplate = {
...baseField,
type: 'vae_model',
inputRequirement: 'always',
inputKind: 'direct',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildLoRAModelInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): LoRAModelInputFieldTemplate => {
const template: LoRAModelInputFieldTemplate = {
...baseField,
type: 'lora_model',
inputRequirement: 'always',
inputKind: 'direct',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildImageInputFieldTemplate = ({
schemaObject,
baseField,
@ -441,6 +473,12 @@ export const buildInputFieldTemplate = (
if (['model'].includes(fieldType)) {
return buildModelInputFieldTemplate({ schemaObject, baseField });
}
if (['vae_model'].includes(fieldType)) {
return buildVaeModelInputFieldTemplate({ schemaObject, baseField });
}
if (['lora_model'].includes(fieldType)) {
return buildLoRAModelInputFieldTemplate({ schemaObject, baseField });
}
if (['enum'].includes(fieldType)) {
return buildEnumInputFieldTemplate({ schemaObject, baseField });
}

View File

@ -75,6 +75,14 @@ export const buildInputFieldValue = (
if (template.type === 'model') {
fieldValue.value = undefined;
}
if (template.type === 'vae_model') {
fieldValue.value = undefined;
}
if (template.type === 'lora_model') {
fieldValue.value = undefined;
}
}
return fieldValue;

View File

@ -0,0 +1,148 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { forEach, size } from 'lodash-es';
import { LoraLoaderInvocation } from 'services/api/types';
import { modelIdToLoRAModelField } from '../modelIdToLoRAName';
import {
LORA_LOADER,
MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING,
POSITIVE_CONDITIONING,
} from './constants';
export const addLoRAsToGraph = (
graph: NonNullableGraph,
state: RootState,
baseNodeId: string
): void => {
/**
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
* or to the inference/conditioning nodes.
*
* So we need to inject a LoRA chain into the graph.
*/
const { loras } = state.lora;
const loraCount = size(loras);
if (loraCount > 0) {
// remove any existing connections from main model loader, we need to insert the lora nodes
graph.edges = graph.edges.filter(
(e) =>
!(
e.source.node_id === MAIN_MODEL_LOADER &&
['unet', 'clip'].includes(e.source.field)
)
);
}
// we need to remember the last lora so we can chain from it
let lastLoraNodeId = '';
let currentLoraIndex = 0;
forEach(loras, (lora) => {
const { id, name, weight } = lora;
const loraField = modelIdToLoRAModelField(id);
const currentLoraNodeId = `${LORA_LOADER}_${loraField.model_name.replace(
'.',
'_'
)}`;
const loraLoaderNode: LoraLoaderInvocation = {
type: 'lora_loader',
id: currentLoraNodeId,
lora: loraField,
weight,
};
graph.nodes[currentLoraNodeId] = loraLoaderNode;
if (currentLoraIndex === 0) {
// first lora = start the lora chain, attach directly to model loader
graph.edges.push({
source: {
node_id: MAIN_MODEL_LOADER,
field: 'unet',
},
destination: {
node_id: currentLoraNodeId,
field: 'unet',
},
});
graph.edges.push({
source: {
node_id: MAIN_MODEL_LOADER,
field: 'clip',
},
destination: {
node_id: currentLoraNodeId,
field: 'clip',
},
});
} else {
// we are in the middle of the lora chain, instead connect to the previous lora
graph.edges.push({
source: {
node_id: lastLoraNodeId,
field: 'unet',
},
destination: {
node_id: currentLoraNodeId,
field: 'unet',
},
});
graph.edges.push({
source: {
node_id: lastLoraNodeId,
field: 'clip',
},
destination: {
node_id: currentLoraNodeId,
field: 'clip',
},
});
}
if (currentLoraIndex === loraCount - 1) {
// final lora, end the lora chain - we need to connect up to inference and conditioning nodes
graph.edges.push({
source: {
node_id: currentLoraNodeId,
field: 'unet',
},
destination: {
node_id: baseNodeId,
field: 'unet',
},
});
graph.edges.push({
source: {
node_id: currentLoraNodeId,
field: 'clip',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip',
},
});
graph.edges.push({
source: {
node_id: currentLoraNodeId,
field: 'clip',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip',
},
});
}
// increment the lora for the next one in the chain
lastLoraNodeId = currentLoraNodeId;
currentLoraIndex += 1;
});
};

View File

@ -0,0 +1,68 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
import {
IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS,
INPAINT,
INPAINT_GRAPH,
LATENTS_TO_IMAGE,
MAIN_MODEL_LOADER,
TEXT_TO_IMAGE_GRAPH,
VAE_LOADER,
} from './constants';
export const addVAEToGraph = (
graph: NonNullableGraph,
state: RootState
): void => {
const { vae: vaeId } = state.generation;
const vae_model = modelIdToVAEModelField(vaeId);
if (vaeId !== 'auto') {
graph.nodes[VAE_LOADER] = {
type: 'vae_loader',
id: VAE_LOADER,
vae_model,
};
}
if (graph.id === TEXT_TO_IMAGE_GRAPH || graph.id === IMAGE_TO_IMAGE_GRAPH) {
graph.edges.push({
source: {
node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'vae',
},
});
}
if (graph.id === IMAGE_TO_IMAGE_GRAPH) {
graph.edges.push({
source: {
node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae',
},
destination: {
node_id: IMAGE_TO_LATENTS,
field: 'vae',
},
});
}
if (graph.id === INPAINT_GRAPH) {
graph.edges.push({
source: {
node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae',
},
destination: {
node_id: INPAINT,
field: 'vae',
},
});
}
};

View File

@ -1,31 +1,27 @@
import { log } from 'app/logging/useLogger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import {
ImageDTO,
ImageResizeInvocation,
ImageToLatentsInvocation,
RandomIntInvocation,
RangeOfSizeInvocation,
} from 'services/api/types';
import { NonNullableGraph } from 'features/nodes/types/types';
import { log } from 'app/logging/useLogger';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import {
ITERATE,
IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS,
LATENTS_TO_IMAGE,
PIPELINE_MODEL_LOADER,
LATENTS_TO_LATENTS,
MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS,
LATENTS_TO_LATENTS,
RESIZE,
} from './constants';
import { set } from 'lodash-es';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
const moduleLog = log.child({ namespace: 'nodes' });
@ -52,7 +48,7 @@ export const buildCanvasImageToImageGraph = (
// The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions;
const model = modelIdToPipelineModelField(modelId);
const model = modelIdToMainModelField(modelId);
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
@ -81,9 +77,9 @@ export const buildCanvasImageToImageGraph = (
type: 'noise',
id: NOISE,
},
[PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader',
id: PIPELINE_MODEL_LOADER,
[MAIN_MODEL_LOADER]: {
type: 'main_model_loader',
id: MAIN_MODEL_LOADER,
model,
},
[LATENTS_TO_IMAGE]: {
@ -110,7 +106,7 @@ export const buildCanvasImageToImageGraph = (
edges: [
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -120,7 +116,7 @@ export const buildCanvasImageToImageGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -128,16 +124,6 @@ export const buildCanvasImageToImageGraph = (
field: 'clip',
},
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'vae',
},
},
{
source: {
node_id: LATENTS_TO_LATENTS,
@ -170,17 +156,7 @@ export const buildCanvasImageToImageGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: IMAGE_TO_LATENTS,
field: 'vae',
},
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'unet',
},
destination: {
@ -277,6 +253,11 @@ export const buildCanvasImageToImageGraph = (
});
}
addLoRAsToGraph(graph, state, LATENTS_TO_LATENTS);
// Add VAE
addVAEToGraph(graph, state);
// add dynamic prompts, mutating `graph`
addDynamicPromptsToGraph(graph, state);

View File

@ -1,23 +1,25 @@
import { log } from 'app/logging/useLogger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import {
ImageDTO,
InpaintInvocation,
RandomIntInvocation,
RangeOfSizeInvocation,
} from 'services/api/types';
import { NonNullableGraph } from 'features/nodes/types/types';
import { log } from 'app/logging/useLogger';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import {
INPAINT,
INPAINT_GRAPH,
ITERATE,
PIPELINE_MODEL_LOADER,
MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING,
POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
INPAINT_GRAPH,
INPAINT,
} from './constants';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
const moduleLog = log.child({ namespace: 'nodes' });
@ -55,7 +57,7 @@ export const buildCanvasInpaintGraph = (
// We may need to set the inpaint width and height to scale the image
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const model = modelIdToPipelineModelField(modelId);
const model = modelIdToMainModelField(modelId);
const graph: NonNullableGraph = {
id: INPAINT_GRAPH,
@ -101,9 +103,9 @@ export const buildCanvasInpaintGraph = (
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
},
[PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader',
id: PIPELINE_MODEL_LOADER,
[MAIN_MODEL_LOADER]: {
type: 'main_model_loader',
id: MAIN_MODEL_LOADER,
model,
},
[RANGE_OF_SIZE]: {
@ -142,7 +144,7 @@ export const buildCanvasInpaintGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -152,7 +154,7 @@ export const buildCanvasInpaintGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -162,7 +164,7 @@ export const buildCanvasInpaintGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'unet',
},
destination: {
@ -170,16 +172,6 @@ export const buildCanvasInpaintGraph = (
field: 'unet',
},
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: INPAINT,
field: 'vae',
},
},
{
source: {
node_id: RANGE_OF_SIZE,
@ -203,6 +195,11 @@ export const buildCanvasInpaintGraph = (
],
};
addLoRAsToGraph(graph, state, INPAINT);
// Add VAE
addVAEToGraph(graph, state);
// handle seed
if (shouldRandomizeSeed) {
// Random int node to generate the starting seed

View File

@ -1,21 +1,19 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api/types';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import {
ITERATE,
LATENTS_TO_IMAGE,
PIPELINE_MODEL_LOADER,
MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
TEXT_TO_IMAGE_GRAPH,
TEXT_TO_LATENTS,
} from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
/**
* Builds the Canvas tab's Text to Image graph.
@ -38,7 +36,7 @@ export const buildCanvasTextToImageGraph = (
// The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions;
const model = modelIdToPipelineModelField(modelId);
const model = modelIdToMainModelField(modelId);
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
@ -76,9 +74,9 @@ export const buildCanvasTextToImageGraph = (
scheduler,
steps,
},
[PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader',
id: PIPELINE_MODEL_LOADER,
[MAIN_MODEL_LOADER]: {
type: 'main_model_loader',
id: MAIN_MODEL_LOADER,
model,
},
[LATENTS_TO_IMAGE]: {
@ -109,7 +107,7 @@ export const buildCanvasTextToImageGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -119,7 +117,7 @@ export const buildCanvasTextToImageGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -129,7 +127,7 @@ export const buildCanvasTextToImageGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'unet',
},
destination: {
@ -147,16 +145,6 @@ export const buildCanvasTextToImageGraph = (
field: 'latents',
},
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'vae',
},
},
{
source: {
node_id: NOISE,
@ -170,6 +158,11 @@ export const buildCanvasTextToImageGraph = (
],
};
addLoRAsToGraph(graph, state, TEXT_TO_LATENTS);
// Add VAE
addVAEToGraph(graph, state);
// add dynamic prompts, mutating `graph`
addDynamicPromptsToGraph(graph, state);

View File

@ -1,28 +1,30 @@
import { log } from 'app/logging/useLogger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import {
ImageCollectionInvocation,
ImageResizeInvocation,
ImageToLatentsInvocation,
IterateInvocation,
} from 'services/api/types';
import { NonNullableGraph } from 'features/nodes/types/types';
import { log } from 'app/logging/useLogger';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import {
IMAGE_COLLECTION,
IMAGE_COLLECTION_ITERATE,
IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS,
LATENTS_TO_IMAGE,
PIPELINE_MODEL_LOADER,
LATENTS_TO_LATENTS,
MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS,
LATENTS_TO_LATENTS,
RESIZE,
IMAGE_COLLECTION,
IMAGE_COLLECTION_ITERATE,
} from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
const moduleLog = log.child({ namespace: 'nodes' });
@ -69,7 +71,7 @@ export const buildLinearImageToImageGraph = (
throw new Error('No initial image found in state');
}
const model = modelIdToPipelineModelField(modelId);
const model = modelIdToMainModelField(modelId);
// copy-pasted graph from node editor, filled in with state values & friendly node ids
const graph: NonNullableGraph = {
@ -89,9 +91,9 @@ export const buildLinearImageToImageGraph = (
type: 'noise',
id: NOISE,
},
[PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader',
id: PIPELINE_MODEL_LOADER,
[MAIN_MODEL_LOADER]: {
type: 'main_model_loader',
id: MAIN_MODEL_LOADER,
model,
},
[LATENTS_TO_IMAGE]: {
@ -118,7 +120,7 @@ export const buildLinearImageToImageGraph = (
edges: [
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -128,7 +130,7 @@ export const buildLinearImageToImageGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -136,16 +138,6 @@ export const buildLinearImageToImageGraph = (
field: 'clip',
},
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'vae',
},
},
{
source: {
node_id: LATENTS_TO_LATENTS,
@ -176,19 +168,10 @@ export const buildLinearImageToImageGraph = (
field: 'noise',
},
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: IMAGE_TO_LATENTS,
field: 'vae',
},
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'unet',
},
destination: {
@ -323,6 +306,11 @@ export const buildLinearImageToImageGraph = (
});
}
addLoRAsToGraph(graph, state, LATENTS_TO_LATENTS);
// Add VAE
addVAEToGraph(graph, state);
// add dynamic prompts, mutating `graph`
addDynamicPromptsToGraph(graph, state);

View File

@ -1,17 +1,19 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import {
LATENTS_TO_IMAGE,
PIPELINE_MODEL_LOADER,
MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
TEXT_TO_IMAGE_GRAPH,
TEXT_TO_LATENTS,
} from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
export const buildLinearTextToImageGraph = (
state: RootState
@ -27,7 +29,7 @@ export const buildLinearTextToImageGraph = (
height,
} = state.generation;
const model = modelIdToPipelineModelField(modelId);
const model = modelIdToMainModelField(modelId);
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
@ -65,9 +67,9 @@ export const buildLinearTextToImageGraph = (
scheduler,
steps,
},
[PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader',
id: PIPELINE_MODEL_LOADER,
[MAIN_MODEL_LOADER]: {
type: 'main_model_loader',
id: MAIN_MODEL_LOADER,
model,
},
[LATENTS_TO_IMAGE]: {
@ -98,7 +100,7 @@ export const buildLinearTextToImageGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -108,7 +110,7 @@ export const buildLinearTextToImageGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -118,7 +120,7 @@ export const buildLinearTextToImageGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'unet',
},
destination: {
@ -136,16 +138,6 @@ export const buildLinearTextToImageGraph = (
field: 'latents',
},
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'vae',
},
},
{
source: {
node_id: NOISE,
@ -159,6 +151,11 @@ export const buildLinearTextToImageGraph = (
],
};
addLoRAsToGraph(graph, state, TEXT_TO_LATENTS);
// Add Custom VAE Support
addVAEToGraph(graph, state);
// add dynamic prompts, mutating `graph`
addDynamicPromptsToGraph(graph, state);

View File

@ -1,10 +1,12 @@
import { Graph } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
import { cloneDeep, omit, reduce } from 'lodash-es';
import { RootState } from 'app/store/store';
import { InputFieldValue } from 'features/nodes/types/types';
import { cloneDeep, omit, reduce } from 'lodash-es';
import { Graph } from 'services/api/types';
import { AnyInvocation } from 'services/events/types';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { v4 as uuidv4 } from 'uuid';
import { modelIdToLoRAModelField } from '../modelIdToLoRAName';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
/**
* We need to do special handling for some fields
@ -27,7 +29,19 @@ export const parseFieldValue = (field: InputFieldValue) => {
if (field.type === 'model') {
if (field.value) {
return modelIdToPipelineModelField(field.value);
return modelIdToMainModelField(field.value);
}
}
if (field.type === 'vae_model') {
if (field.value) {
return modelIdToVAEModelField(field.value);
}
}
if (field.type === 'lora_model') {
if (field.value) {
return modelIdToLoRAModelField(field.value);
}
}

View File

@ -7,7 +7,9 @@ export const NOISE = 'noise';
export const RANDOM_INT = 'rand_int';
export const RANGE_OF_SIZE = 'range_of_size';
export const ITERATE = 'iterate';
export const PIPELINE_MODEL_LOADER = 'pipeline_model_loader';
export const MAIN_MODEL_LOADER = 'main_model_loader';
export const VAE_LOADER = 'vae_loader';
export const LORA_LOADER = 'lora_loader';
export const IMAGE_TO_LATENTS = 'image_to_latents';
export const LATENTS_TO_LATENTS = 'latents_to_latents';
export const RESIZE = 'resize_image';

View File

@ -0,0 +1,12 @@
import { BaseModelType, LoRAModelField } from 'services/api/types';
export const modelIdToLoRAModelField = (loraId: string): LoRAModelField => {
const [base_model, model_type, model_name] = loraId.split('/');
const field: LoRAModelField = {
base_model: base_model as BaseModelType,
model_name,
};
return field;
};

View File

@ -0,0 +1,16 @@
import { BaseModelType, MainModelField } from 'services/api/types';
/**
* Crudely converts a model id to a main model field
* TODO: Make better
*/
export const modelIdToMainModelField = (modelId: string): MainModelField => {
const [base_model, model_type, model_name] = modelId.split('/');
const field: MainModelField = {
base_model: base_model as BaseModelType,
model_name,
};
return field;
};

View File

@ -1,18 +0,0 @@
import { BaseModelType, PipelineModelField } from 'services/api/types';
/**
* Crudely converts a model id to a pipeline model field
* TODO: Make better
*/
export const modelIdToPipelineModelField = (
modelId: string
): PipelineModelField => {
const [base_model, model_type, model_name] = modelId.split('/');
const field: PipelineModelField = {
base_model: base_model as BaseModelType,
model_name,
};
return field;
};

View File

@ -0,0 +1,16 @@
import { BaseModelType, VAEModelField } from 'services/api/types';
/**
* Crudely converts a model id to a main model field
* TODO: Make better
*/
export const modelIdToVAEModelField = (modelId: string): VAEModelField => {
const [base_model, model_type, model_name] = modelId.split('/');
const field: VAEModelField = {
base_model: base_model as BaseModelType,
model_name,
};
return field;
};

View File

@ -1,20 +1,15 @@
import { Flex, useDisclosure } from '@chakra-ui/react';
import { useTranslation } from 'react-i18next';
import { Flex } from '@chakra-ui/react';
import IAICollapse from 'common/components/IAICollapse';
import { memo } from 'react';
import ParamBoundingBoxWidth from './ParamBoundingBoxWidth';
import { useTranslation } from 'react-i18next';
import ParamBoundingBoxHeight from './ParamBoundingBoxHeight';
import ParamBoundingBoxWidth from './ParamBoundingBoxWidth';
const ParamBoundingBoxCollapse = () => {
const { t } = useTranslation();
const { isOpen, onToggle } = useDisclosure();
return (
<IAICollapse
label={t('parameters.boundingBoxHeader')}
isOpen={isOpen}
onToggle={onToggle}
>
<IAICollapse label={t('parameters.boundingBoxHeader')}>
<Flex sx={{ gap: 2, flexDirection: 'column' }}>
<ParamBoundingBoxWidth />
<ParamBoundingBoxHeight />

View File

@ -1,4 +1,4 @@
import { Flex, useDisclosure } from '@chakra-ui/react';
import { Flex } from '@chakra-ui/react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
@ -6,19 +6,14 @@ import IAICollapse from 'common/components/IAICollapse';
import ParamInfillMethod from './ParamInfillMethod';
import ParamInfillTilesize from './ParamInfillTilesize';
import ParamScaleBeforeProcessing from './ParamScaleBeforeProcessing';
import ParamScaledWidth from './ParamScaledWidth';
import ParamScaledHeight from './ParamScaledHeight';
import ParamScaledWidth from './ParamScaledWidth';
const ParamInfillCollapse = () => {
const { t } = useTranslation();
const { isOpen, onToggle } = useDisclosure();
return (
<IAICollapse
label={t('parameters.infillScalingHeader')}
isOpen={isOpen}
onToggle={onToggle}
>
<IAICollapse label={t('parameters.infillScalingHeader')}>
<Flex sx={{ gap: 2, flexDirection: 'column' }}>
<ParamInfillMethod />
<ParamInfillTilesize />

View File

@ -1,22 +1,16 @@
import IAICollapse from 'common/components/IAICollapse';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import ParamSeamBlur from './ParamSeamBlur';
import ParamSeamSize from './ParamSeamSize';
import ParamSeamSteps from './ParamSeamSteps';
import ParamSeamStrength from './ParamSeamStrength';
import { useDisclosure } from '@chakra-ui/react';
import { useTranslation } from 'react-i18next';
import IAICollapse from 'common/components/IAICollapse';
import { memo } from 'react';
const ParamSeamCorrectionCollapse = () => {
const { t } = useTranslation();
const { isOpen, onToggle } = useDisclosure();
return (
<IAICollapse
label={t('parameters.seamCorrectionHeader')}
isOpen={isOpen}
onToggle={onToggle}
>
<IAICollapse label={t('parameters.seamCorrectionHeader')}>
<ParamSeamSize />
<ParamSeamBlur />
<ParamSeamStrength />

View File

@ -1,41 +1,45 @@
import { Divider, Flex } from '@chakra-ui/react';
import { useTranslation } from 'react-i18next';
import IAICollapse from 'common/components/IAICollapse';
import { Fragment, memo, useCallback } from 'react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton';
import IAICollapse from 'common/components/IAICollapse';
import ControlNet from 'features/controlNet/components/ControlNet';
import ParamControlNetFeatureToggle from 'features/controlNet/components/parameters/ParamControlNetFeatureToggle';
import {
controlNetAdded,
controlNetSelector,
isControlNetEnabledToggled,
} from 'features/controlNet/store/controlNetSlice';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { map } from 'lodash-es';
import { v4 as uuidv4 } from 'uuid';
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import IAIButton from 'common/components/IAIButton';
import ControlNet from 'features/controlNet/components/ControlNet';
import { map } from 'lodash-es';
import { Fragment, memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { v4 as uuidv4 } from 'uuid';
const selector = createSelector(
controlNetSelector,
(controlNet) => {
const { controlNets, isEnabled } = controlNet;
return { controlNetsArray: map(controlNets), isEnabled };
const validControlNets = getValidControlNets(controlNets);
const activeLabel =
isEnabled && validControlNets.length > 0
? `${validControlNets.length} Active`
: undefined;
return { controlNetsArray: map(controlNets), activeLabel };
},
defaultSelectorOptions
);
const ParamControlNetCollapse = () => {
const { t } = useTranslation();
const { controlNetsArray, isEnabled } = useAppSelector(selector);
const { controlNetsArray, activeLabel } = useAppSelector(selector);
const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled;
const dispatch = useAppDispatch();
const handleClickControlNetToggle = useCallback(() => {
dispatch(isControlNetEnabledToggled());
}, [dispatch]);
const handleClickedAddControlNet = useCallback(() => {
dispatch(controlNetAdded({ controlNetId: uuidv4() }));
}, [dispatch]);
@ -45,13 +49,9 @@ const ParamControlNetCollapse = () => {
}
return (
<IAICollapse
label={'ControlNet'}
isOpen={isEnabled}
onToggle={handleClickControlNetToggle}
withSwitch
>
<IAICollapse label="ControlNet" activeLabel={activeLabel}>
<Flex sx={{ flexDir: 'column', gap: 3 }}>
<ParamControlNetFeatureToggle />
{controlNetsArray.map((c, i) => (
<Fragment key={c.controlNetId}>
{i > 0 && <Divider />}

View File

@ -1,5 +1,6 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider';
import { generationSelector } from 'features/parameters/store/generationSelectors';
@ -27,7 +28,8 @@ const selector = createSelector(
shouldUseSliders,
shift,
};
}
},
defaultSelectorOptions
);
const ParamCFGScale = () => {

View File

@ -1,5 +1,6 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider, { IAIFullSliderProps } from 'common/components/IAISlider';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { setHeight } from 'features/parameters/store/generationSlice';
@ -25,7 +26,8 @@ const selector = createSelector(
inputMax,
step,
};
}
},
defaultSelectorOptions
);
type ParamHeightProps = Omit<

View File

@ -1,37 +1,38 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { setIterations } from 'features/parameters/store/generationSlice';
import { configSelector } from 'features/system/store/configSelectors';
import { hotkeysSelector } from 'features/ui/store/hotkeysSlice';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const selector = createSelector([stateSelector], (state) => {
const { initial, min, sliderMax, inputMax, fineStep, coarseStep } =
state.config.sd.iterations;
const { iterations } = state.generation;
const { shouldUseSliders } = state.ui;
const isDisabled =
state.dynamicPrompts.isEnabled && state.dynamicPrompts.combinatorial;
const selector = createSelector(
[stateSelector],
(state) => {
const { initial, min, sliderMax, inputMax, fineStep, coarseStep } =
state.config.sd.iterations;
const { iterations } = state.generation;
const { shouldUseSliders } = state.ui;
const isDisabled =
state.dynamicPrompts.isEnabled && state.dynamicPrompts.combinatorial;
const step = state.hotkeys.shift ? fineStep : coarseStep;
const step = state.hotkeys.shift ? fineStep : coarseStep;
return {
iterations,
initial,
min,
sliderMax,
inputMax,
step,
shouldUseSliders,
isDisabled,
};
});
return {
iterations,
initial,
min,
sliderMax,
inputMax,
step,
shouldUseSliders,
isDisabled,
};
},
defaultSelectorOptions
);
const ParamIterations = () => {
const {

View File

@ -1,19 +1,19 @@
import { Box, Flex } from '@chakra-ui/react';
import ModelSelect from 'features/system/components/ModelSelect';
import VAESelect from 'features/system/components/VAESelect';
import { memo } from 'react';
import ParamScheduler from './ParamScheduler';
const ParamSchedulerAndModel = () => {
const ParamModelandVAE = () => {
return (
<Flex gap={3} w="full">
<Box w="25rem">
<ParamScheduler />
</Box>
<Box w="full">
<ModelSelect />
</Box>
<Box w="full">
<VAESelect />
</Box>
</Flex>
);
};
export default memo(ParamSchedulerAndModel);
export default memo(ParamModelandVAE);

View File

@ -1,5 +1,6 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider';
@ -33,7 +34,8 @@ const selector = createSelector(
step,
shouldUseSliders,
};
}
},
defaultSelectorOptions
);
const ParamSteps = () => {

View File

@ -1,7 +1,7 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { IAIFullSliderProps } from 'common/components/IAISlider';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider, { IAIFullSliderProps } from 'common/components/IAISlider';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { setWidth } from 'features/parameters/store/generationSlice';
import { configSelector } from 'features/system/store/configSelectors';
@ -26,7 +26,8 @@ const selector = createSelector(
inputMax,
step,
};
}
},
defaultSelectorOptions
);
type ParamWidthProps = Omit<IAIFullSliderProps, 'label' | 'value' | 'onChange'>;

View File

@ -1,37 +1,39 @@
import { Flex } from '@chakra-ui/react';
import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { RootState } from 'app/store/store';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import { memo } from 'react';
import { ParamHiresStrength } from './ParamHiresStrength';
import { setHiresFix } from 'features/parameters/store/postprocessingSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { ParamHiresStrength } from './ParamHiresStrength';
import { ParamHiresToggle } from './ParamHiresToggle';
const selector = createSelector(
stateSelector,
(state) => {
const activeLabel = state.postprocessing.hiresFix ? 'Enabled' : undefined;
return { activeLabel };
},
defaultSelectorOptions
);
const ParamHiresCollapse = () => {
const { t } = useTranslation();
const hiresFix = useAppSelector(
(state: RootState) => state.postprocessing.hiresFix
);
const { activeLabel } = useAppSelector(selector);
const isHiresEnabled = useFeatureStatus('hires').isFeatureEnabled;
const dispatch = useAppDispatch();
const handleToggle = () => dispatch(setHiresFix(!hiresFix));
if (!isHiresEnabled) {
return null;
}
return (
<IAICollapse
label={t('parameters.hiresOptim')}
isOpen={hiresFix}
onToggle={handleToggle}
withSwitch
>
<IAICollapse label={t('parameters.hiresOptim')} activeLabel={activeLabel}>
<Flex sx={{ gap: 2, flexDirection: 'column' }}>
<ParamHiresToggle />
<ParamHiresStrength />
</Flex>
</IAICollapse>

View File

@ -23,7 +23,6 @@ export const ParamHiresToggle = () => {
return (
<IAISwitch
label={t('parameters.hiresOptim')}
fontSize="md"
isChecked={hiresFix}
onChange={handleChangeHiresFix}
/>

View File

@ -1,27 +1,33 @@
import { useTranslation } from 'react-i18next';
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import ParamPerlinNoise from './ParamPerlinNoise';
import ParamNoiseThreshold from './ParamNoiseThreshold';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setShouldUseNoiseSettings } from 'features/parameters/store/generationSlice';
import { memo } from 'react';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import ParamNoiseThreshold from './ParamNoiseThreshold';
import { ParamNoiseToggle } from './ParamNoiseToggle';
import ParamPerlinNoise from './ParamPerlinNoise';
const selector = createSelector(
stateSelector,
(state) => {
const { shouldUseNoiseSettings } = state.generation;
return {
activeLabel: shouldUseNoiseSettings ? 'Enabled' : undefined,
};
},
defaultSelectorOptions
);
const ParamNoiseCollapse = () => {
const { t } = useTranslation();
const isNoiseEnabled = useFeatureStatus('noise').isFeatureEnabled;
const shouldUseNoiseSettings = useAppSelector(
(state: RootState) => state.generation.shouldUseNoiseSettings
);
const dispatch = useAppDispatch();
const handleToggle = () =>
dispatch(setShouldUseNoiseSettings(!shouldUseNoiseSettings));
const { activeLabel } = useAppSelector(selector);
if (!isNoiseEnabled) {
return null;
@ -30,11 +36,10 @@ const ParamNoiseCollapse = () => {
return (
<IAICollapse
label={t('parameters.noiseSettings')}
isOpen={shouldUseNoiseSettings}
onToggle={handleToggle}
withSwitch
activeLabel={activeLabel}
>
<Flex sx={{ gap: 2, flexDirection: 'column' }}>
<ParamNoiseToggle />
<ParamPerlinNoise />
<ParamNoiseThreshold />
</Flex>

View File

@ -1,18 +1,31 @@
import { RootState } from 'app/store/store';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider';
import { setThreshold } from 'features/parameters/store/generationSlice';
import { useTranslation } from 'react-i18next';
const selector = createSelector(
stateSelector,
(state) => {
const { shouldUseNoiseSettings, threshold } = state.generation;
return {
isDisabled: !shouldUseNoiseSettings,
threshold,
};
},
defaultSelectorOptions
);
export default function ParamNoiseThreshold() {
const dispatch = useAppDispatch();
const threshold = useAppSelector(
(state: RootState) => state.generation.threshold
);
const { threshold, isDisabled } = useAppSelector(selector);
const { t } = useTranslation();
return (
<IAISlider
isDisabled={isDisabled}
label={t('parameters.noiseThreshold')}
min={0}
max={20}

View File

@ -0,0 +1,27 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { setShouldUseNoiseSettings } from 'features/parameters/store/generationSlice';
import { ChangeEvent } from 'react';
import { useTranslation } from 'react-i18next';
export const ParamNoiseToggle = () => {
const dispatch = useAppDispatch();
const shouldUseNoiseSettings = useAppSelector(
(state: RootState) => state.generation.shouldUseNoiseSettings
);
const { t } = useTranslation();
const handleChange = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldUseNoiseSettings(e.target.checked));
return (
<IAISwitch
label="Enable Noise Settings"
isChecked={shouldUseNoiseSettings}
onChange={handleChange}
/>
);
};

View File

@ -1,16 +1,31 @@
import { RootState } from 'app/store/store';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider';
import { setPerlin } from 'features/parameters/store/generationSlice';
import { useTranslation } from 'react-i18next';
const selector = createSelector(
stateSelector,
(state) => {
const { shouldUseNoiseSettings, perlin } = state.generation;
return {
isDisabled: !shouldUseNoiseSettings,
perlin,
};
},
defaultSelectorOptions
);
export default function ParamPerlinNoise() {
const dispatch = useAppDispatch();
const perlin = useAppSelector((state: RootState) => state.generation.perlin);
const { perlin, isDisabled } = useAppSelector(selector);
const { t } = useTranslation();
return (
<IAISlider
isDisabled={isDisabled}
label={t('parameters.perlinNoise')}
min={0}
max={1}

View File

@ -1,36 +1,46 @@
import { useTranslation } from 'react-i18next';
import { Box, Flex } from '@chakra-ui/react';
import IAICollapse from 'common/components/IAICollapse';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setSeamless } from 'features/parameters/store/generationSlice';
import { memo } from 'react';
import { createSelector } from '@reduxjs/toolkit';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import ParamSeamlessXAxis from './ParamSeamlessXAxis';
import ParamSeamlessYAxis from './ParamSeamlessYAxis';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
const getActiveLabel = (seamlessXAxis: boolean, seamlessYAxis: boolean) => {
if (seamlessXAxis && seamlessYAxis) {
return 'X & Y';
}
if (seamlessXAxis) {
return 'X';
}
if (seamlessYAxis) {
return 'Y';
}
};
const selector = createSelector(
generationSelector,
(generation) => {
const { shouldUseSeamless, seamlessXAxis, seamlessYAxis } = generation;
const { seamlessXAxis, seamlessYAxis } = generation;
return { shouldUseSeamless, seamlessXAxis, seamlessYAxis };
const activeLabel = getActiveLabel(seamlessXAxis, seamlessYAxis);
return { activeLabel };
},
defaultSelectorOptions
);
const ParamSeamlessCollapse = () => {
const { t } = useTranslation();
const { shouldUseSeamless } = useAppSelector(selector);
const { activeLabel } = useAppSelector(selector);
const isSeamlessEnabled = useFeatureStatus('seamless').isFeatureEnabled;
const dispatch = useAppDispatch();
const handleToggle = () => dispatch(setSeamless(!shouldUseSeamless));
if (!isSeamlessEnabled) {
return null;
}
@ -38,9 +48,7 @@ const ParamSeamlessCollapse = () => {
return (
<IAICollapse
label={t('parameters.seamlessTiling')}
isOpen={shouldUseSeamless}
onToggle={handleToggle}
withSwitch
activeLabel={activeLabel}
>
<Flex sx={{ gap: 5 }}>
<Box flexGrow={1}>

View File

@ -1,39 +1,39 @@
import { memo } from 'react';
import { Flex } from '@chakra-ui/react';
import { memo } from 'react';
import ParamSymmetryHorizontal from './ParamSymmetryHorizontal';
import ParamSymmetryVertical from './ParamSymmetryVertical';
import { useTranslation } from 'react-i18next';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setShouldUseSymmetry } from 'features/parameters/store/generationSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useTranslation } from 'react-i18next';
import ParamSymmetryToggle from './ParamSymmetryToggle';
const selector = createSelector(
stateSelector,
(state) => ({
activeLabel: state.generation.shouldUseSymmetry ? 'Enabled' : undefined,
}),
defaultSelectorOptions
);
const ParamSymmetryCollapse = () => {
const { t } = useTranslation();
const shouldUseSymmetry = useAppSelector(
(state: RootState) => state.generation.shouldUseSymmetry
);
const { activeLabel } = useAppSelector(selector);
const isSymmetryEnabled = useFeatureStatus('symmetry').isFeatureEnabled;
const dispatch = useAppDispatch();
const handleToggle = () => dispatch(setShouldUseSymmetry(!shouldUseSymmetry));
if (!isSymmetryEnabled) {
return null;
}
return (
<IAICollapse
label={t('parameters.symmetry')}
isOpen={shouldUseSymmetry}
onToggle={handleToggle}
withSwitch
>
<IAICollapse label={t('parameters.symmetry')} activeLabel={activeLabel}>
<Flex sx={{ gap: 2, flexDirection: 'column' }}>
<ParamSymmetryToggle />
<ParamSymmetryHorizontal />
<ParamSymmetryVertical />
</Flex>

View File

@ -12,6 +12,7 @@ export default function ParamSymmetryToggle() {
return (
<IAISwitch
label="Enable Symmetry"
isChecked={shouldUseSymmetry}
onChange={(e) => dispatch(setShouldUseSymmetry(e.target.checked))}
/>

View File

@ -1,39 +1,42 @@
import ParamVariationWeights from './ParamVariationWeights';
import ParamVariationAmount from './ParamVariationAmount';
import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { RootState } from 'app/store/store';
import { setShouldGenerateVariations } from 'features/parameters/store/generationSlice';
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import { memo } from 'react';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import ParamVariationAmount from './ParamVariationAmount';
import { ParamVariationToggle } from './ParamVariationToggle';
import ParamVariationWeights from './ParamVariationWeights';
const selector = createSelector(
stateSelector,
(state) => {
const activeLabel = state.generation.shouldGenerateVariations
? 'Enabled'
: undefined;
return { activeLabel };
},
defaultSelectorOptions
);
const ParamVariationCollapse = () => {
const { t } = useTranslation();
const shouldGenerateVariations = useAppSelector(
(state: RootState) => state.generation.shouldGenerateVariations
);
const { activeLabel } = useAppSelector(selector);
const isVariationEnabled = useFeatureStatus('variation').isFeatureEnabled;
const dispatch = useAppDispatch();
const handleToggle = () =>
dispatch(setShouldGenerateVariations(!shouldGenerateVariations));
if (!isVariationEnabled) {
return null;
}
return (
<IAICollapse
label={t('parameters.variations')}
isOpen={shouldGenerateVariations}
onToggle={handleToggle}
withSwitch
>
<IAICollapse label={t('parameters.variations')} activeLabel={activeLabel}>
<Flex sx={{ gap: 2, flexDirection: 'column' }}>
<ParamVariationToggle />
<ParamVariationAmount />
<ParamVariationWeights />
</Flex>

View File

@ -0,0 +1,27 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { setShouldGenerateVariations } from 'features/parameters/store/generationSlice';
import { ChangeEvent } from 'react';
import { useTranslation } from 'react-i18next';
export const ParamVariationToggle = () => {
const dispatch = useAppDispatch();
const shouldGenerateVariations = useAppSelector(
(state: RootState) => state.generation.shouldGenerateVariations
);
const { t } = useTranslation();
const handleChange = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldGenerateVariations(e.target.checked));
return (
<IAISwitch
label="Enable Variations"
isChecked={shouldGenerateVariations}
onChange={handleChange}
/>
);
};

View File

@ -14,6 +14,7 @@ import {
SeedParam,
StepsParam,
StrengthParam,
VAEParam,
WidthParam,
} from './parameterZodSchemas';
@ -47,7 +48,7 @@ export interface GenerationState {
horizontalSymmetrySteps: number;
verticalSymmetrySteps: number;
model: ModelParam;
shouldUseSeamless: boolean;
vae: VAEParam;
seamlessXAxis: boolean;
seamlessYAxis: boolean;
}
@ -81,9 +82,9 @@ export const initialGenerationState: GenerationState = {
horizontalSymmetrySteps: 0,
verticalSymmetrySteps: 0,
model: '',
shouldUseSeamless: false,
seamlessXAxis: true,
seamlessYAxis: true,
vae: '',
seamlessXAxis: false,
seamlessYAxis: false,
};
const initialState: GenerationState = initialGenerationState;
@ -141,9 +142,6 @@ export const generationSlice = createSlice({
setImg2imgStrength: (state, action: PayloadAction<number>) => {
state.img2imgStrength = action.payload;
},
setSeamless: (state, action: PayloadAction<boolean>) => {
state.shouldUseSeamless = action.payload;
},
setSeamlessXAxis: (state, action: PayloadAction<boolean>) => {
state.seamlessXAxis = action.payload;
},
@ -216,6 +214,9 @@ export const generationSlice = createSlice({
modelSelected: (state, action: PayloadAction<string>) => {
state.model = action.payload;
},
vaeSelected: (state, action: PayloadAction<string>) => {
state.vae = action.payload;
},
},
extraReducers: (builder) => {
builder.addCase(configChanged, (state, action) => {
@ -260,8 +261,8 @@ export const {
setVerticalSymmetrySteps,
initialImageChanged,
modelSelected,
vaeSelected,
setShouldUseNoiseSettings,
setSeamless,
setSeamlessXAxis,
setSeamlessYAxis,
} = generationSlice.actions;

View File

@ -135,6 +135,15 @@ export const zModel = z.string();
* Type alias for model parameter, inferred from its zod schema
*/
export type ModelParam = z.infer<typeof zModel>;
/**
* Zod schema for VAE parameter
* TODO: Make this a dynamically generated enum?
*/
export const zVAE = z.string();
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type VAEParam = z.infer<typeof zVAE>;
/**
* Validates/type-guards a value as a model parameter
*/

View File

@ -1,125 +0,0 @@
import {
Button,
Flex,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
Text,
useDisclosure,
} from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton';
import { FaArrowLeft, FaPlus } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useTranslation } from 'react-i18next';
import type { RootState } from 'app/store/store';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import AddCheckpointModel from './AddCheckpointModel';
import AddDiffusersModel from './AddDiffusersModel';
import IAIIconButton from 'common/components/IAIIconButton';
function AddModelBox({
text,
onClick,
}: {
text: string;
onClick?: () => void;
}) {
return (
<Flex
position="relative"
width="50%"
height={40}
justifyContent="center"
alignItems="center"
onClick={onClick}
as={Button}
>
<Text fontWeight="bold">{text}</Text>
</Flex>
);
}
export default function AddModel() {
const { isOpen, onOpen, onClose } = useDisclosure();
const addNewModelUIOption = useAppSelector(
(state: RootState) => state.ui.addNewModelUIOption
);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const addModelModalClose = () => {
onClose();
dispatch(setAddNewModelUIOption(null));
};
return (
<>
<IAIButton
aria-label={t('modelManager.addNewModel')}
tooltip={t('modelManager.addNewModel')}
onClick={onOpen}
size="sm"
>
<Flex columnGap={2} alignItems="center">
<FaPlus />
{t('modelManager.addNew')}
</Flex>
</IAIButton>
<Modal
isOpen={isOpen}
onClose={addModelModalClose}
size="3xl"
closeOnOverlayClick={false}
>
<ModalOverlay />
<ModalContent margin="auto">
<ModalHeader>{t('modelManager.addNewModel')} </ModalHeader>
{addNewModelUIOption !== null && (
<IAIIconButton
aria-label={t('common.back')}
tooltip={t('common.back')}
onClick={() => dispatch(setAddNewModelUIOption(null))}
position="absolute"
variant="ghost"
zIndex={1}
size="sm"
insetInlineEnd={12}
top={2}
icon={<FaArrowLeft />}
/>
)}
<ModalCloseButton />
<ModalBody>
{addNewModelUIOption == null && (
<Flex columnGap={4}>
<AddModelBox
text={t('modelManager.addCheckpointModel')}
onClick={() => dispatch(setAddNewModelUIOption('ckpt'))}
/>
<AddModelBox
text={t('modelManager.addDiffuserModel')}
onClick={() => dispatch(setAddNewModelUIOption('diffusers'))}
/>
</Flex>
)}
{addNewModelUIOption == 'ckpt' && <AddCheckpointModel />}
{addNewModelUIOption == 'diffusers' && <AddDiffusersModel />}
</ModalBody>
<ModalFooter />
</ModalContent>
</Modal>
</>
);
}

View File

@ -1,339 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAINumberInput from 'common/components/IAINumberInput';
import { useEffect, useState } from 'react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { systemSelector } from 'features/system/store/systemSelectors';
import {
Flex,
FormControl,
FormLabel,
HStack,
Text,
VStack,
} from '@chakra-ui/react';
// import { addNewModel } from 'app/socketio/actions';
import { Field, Formik } from 'formik';
import { useTranslation } from 'react-i18next';
import type { InvokeModelConfigProps } from 'app/types/invokeai';
import type { RootState } from 'app/store/store';
import type { FieldInputProps, FormikProps } from 'formik';
import { isEqual, pickBy } from 'lodash-es';
import ModelConvert from './ModelConvert';
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
import IAIForm from 'common/components/IAIForm';
const selector = createSelector(
[systemSelector],
(system) => {
const { openModel, model_list } = system;
return {
model_list,
openModel,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const MIN_MODEL_SIZE = 64;
const MAX_MODEL_SIZE = 2048;
export default function CheckpointModelEdit() {
const { openModel, model_list } = useAppSelector(selector);
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const [editModelFormValues, setEditModelFormValues] =
useState<InvokeModelConfigProps>({
name: '',
description: '',
config: 'configs/stable-diffusion/v1-inference.yaml',
weights: '',
vae: '',
width: 512,
height: 512,
default: false,
format: 'ckpt',
});
useEffect(() => {
if (openModel) {
const retrievedModel = pickBy(model_list, (_val, key) => {
return isEqual(key, openModel);
});
setEditModelFormValues({
name: openModel,
description: retrievedModel[openModel]?.description,
config: retrievedModel[openModel]?.config,
weights: retrievedModel[openModel]?.weights,
vae: retrievedModel[openModel]?.vae,
width: retrievedModel[openModel]?.width,
height: retrievedModel[openModel]?.height,
default: retrievedModel[openModel]?.default,
format: 'ckpt',
});
}
}, [model_list, openModel]);
const editModelFormSubmitHandler = (values: InvokeModelConfigProps) => {
dispatch(
addNewModel({
...values,
width: Number(values.width),
height: Number(values.height),
})
);
};
return openModel ? (
<Flex flexDirection="column" rowGap={4} width="100%">
<Flex alignItems="center" gap={4} justifyContent="space-between">
<Text fontSize="lg" fontWeight="bold">
{openModel}
</Text>
<ModelConvert model={openModel} />
</Flex>
<Flex
flexDirection="column"
maxHeight={window.innerHeight - 270}
overflowY="scroll"
paddingInlineEnd={8}
>
<Formik
enableReinitialize={true}
initialValues={editModelFormValues}
onSubmit={editModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<IAIForm onSubmit={handleSubmit}>
<VStack rowGap={2} alignItems="start">
{/* Description */}
<FormControl
isInvalid={!!errors.description && touched.description}
isRequired
>
<FormLabel htmlFor="description" fontSize="sm">
{t('modelManager.description')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="description"
name="description"
type="text"
width="full"
/>
{!!errors.description && touched.description ? (
<IAIFormErrorMessage>
{errors.description}
</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.descriptionValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* Config */}
<FormControl
isInvalid={!!errors.config && touched.config}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.config')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="config"
name="config"
type="text"
width="full"
/>
{!!errors.config && touched.config ? (
<IAIFormErrorMessage>{errors.config}</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.configValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* Weights */}
<FormControl
isInvalid={!!errors.weights && touched.weights}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.modelLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="weights"
name="weights"
type="text"
width="full"
/>
{!!errors.weights && touched.weights ? (
<IAIFormErrorMessage>
{errors.weights}
</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.modelLocationValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* VAE */}
<FormControl isInvalid={!!errors.vae && touched.vae}>
<FormLabel htmlFor="vae" fontSize="sm">
{t('modelManager.vaeLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae"
name="vae"
type="text"
width="full"
/>
{!!errors.vae && touched.vae ? (
<IAIFormErrorMessage>{errors.vae}</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.vaeLocationValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
<HStack width="100%">
{/* Width */}
<FormControl isInvalid={!!errors.width && touched.width}>
<FormLabel htmlFor="width" fontSize="sm">
{t('modelManager.width')}
</FormLabel>
<VStack alignItems="start">
<Field id="width" name="width">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="width"
name="width"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
value={form.values.width}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
{!!errors.width && touched.width ? (
<IAIFormErrorMessage>
{errors.width}
</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.widthValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* Height */}
<FormControl isInvalid={!!errors.height && touched.height}>
<FormLabel htmlFor="height" fontSize="sm">
{t('modelManager.height')}
</FormLabel>
<VStack alignItems="start">
<Field id="height" name="height">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="height"
name="height"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
value={form.values.height}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
{!!errors.height && touched.height ? (
<IAIFormErrorMessage>
{errors.height}
</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.heightValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
</HStack>
<IAIButton
type="submit"
className="modal-close-btn"
isLoading={isProcessing}
>
{t('modelManager.updateModel')}
</IAIButton>
</VStack>
</IAIForm>
)}
</Formik>
</Flex>
</Flex>
) : (
<Flex
sx={{
width: '100%',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
bg: 'base.900',
}}
>
<Text fontWeight={500}>Pick A Model To Edit</Text>
</Flex>
);
}

View File

@ -1,281 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import { useEffect, useState } from 'react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { systemSelector } from 'features/system/store/systemSelectors';
import { Flex, FormControl, FormLabel, Text, VStack } from '@chakra-ui/react';
// import { addNewModel } from 'app/socketio/actions';
import { Field, Formik } from 'formik';
import { useTranslation } from 'react-i18next';
import type { InvokeDiffusersModelConfigProps } from 'app/types/invokeai';
import type { RootState } from 'app/store/store';
import { isEqual, pickBy } from 'lodash-es';
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
import IAIForm from 'common/components/IAIForm';
const selector = createSelector(
[systemSelector],
(system) => {
const { openModel, model_list } = system;
return {
model_list,
openModel,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
export default function DiffusersModelEdit() {
const { openModel, model_list } = useAppSelector(selector);
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const [editModelFormValues, setEditModelFormValues] =
useState<InvokeDiffusersModelConfigProps>({
name: '',
description: '',
repo_id: '',
path: '',
vae: { repo_id: '', path: '' },
default: false,
format: 'diffusers',
});
useEffect(() => {
if (openModel) {
const retrievedModel = pickBy(model_list, (_val, key) => {
return isEqual(key, openModel);
});
setEditModelFormValues({
name: openModel,
description: retrievedModel[openModel]?.description,
path:
retrievedModel[openModel]?.path &&
retrievedModel[openModel]?.path !== 'None'
? retrievedModel[openModel]?.path
: '',
repo_id:
retrievedModel[openModel]?.repo_id &&
retrievedModel[openModel]?.repo_id !== 'None'
? retrievedModel[openModel]?.repo_id
: '',
vae: {
repo_id: retrievedModel[openModel]?.vae?.repo_id
? retrievedModel[openModel]?.vae?.repo_id
: '',
path: retrievedModel[openModel]?.vae?.path
? retrievedModel[openModel]?.vae?.path
: '',
},
default: retrievedModel[openModel]?.default,
format: 'diffusers',
});
}
}, [model_list, openModel]);
const editModelFormSubmitHandler = (
values: InvokeDiffusersModelConfigProps
) => {
const diffusersModelToEdit = values;
if (values.path === '') delete diffusersModelToEdit.path;
if (values.repo_id === '') delete diffusersModelToEdit.repo_id;
if (values.vae.path === '') delete diffusersModelToEdit.vae.path;
if (values.vae.repo_id === '') delete diffusersModelToEdit.vae.repo_id;
dispatch(addNewModel(values));
};
return openModel ? (
<Flex flexDirection="column" rowGap={4} width="100%">
<Flex alignItems="center">
<Text fontSize="lg" fontWeight="bold">
{openModel}
</Text>
</Flex>
<Flex flexDirection="column" overflowY="scroll" paddingInlineEnd={8}>
<Formik
enableReinitialize={true}
initialValues={editModelFormValues}
onSubmit={editModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<IAIForm onSubmit={handleSubmit}>
<VStack rowGap={2} alignItems="start">
{/* Description */}
<FormControl
isInvalid={!!errors.description && touched.description}
isRequired
>
<FormLabel htmlFor="description" fontSize="sm">
{t('modelManager.description')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="description"
name="description"
type="text"
width="full"
/>
{!!errors.description && touched.description ? (
<IAIFormErrorMessage>
{errors.description}
</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.descriptionValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* Path */}
<FormControl
isInvalid={!!errors.path && touched.path}
isRequired
>
<FormLabel htmlFor="path" fontSize="sm">
{t('modelManager.modelLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="path"
name="path"
type="text"
width="full"
/>
{!!errors.path && touched.path ? (
<IAIFormErrorMessage>{errors.path}</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.modelLocationValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* Repo ID */}
<FormControl isInvalid={!!errors.repo_id && touched.repo_id}>
<FormLabel htmlFor="repo_id" fontSize="sm">
{t('modelManager.repo_id')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="repo_id"
name="repo_id"
type="text"
width="full"
/>
{!!errors.repo_id && touched.repo_id ? (
<IAIFormErrorMessage>
{errors.repo_id}
</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.repoIDValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* VAE Path */}
<FormControl
isInvalid={!!errors.vae?.path && touched.vae?.path}
>
<FormLabel htmlFor="vae.path" fontSize="sm">
{t('modelManager.vaeLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae.path"
name="vae.path"
type="text"
width="full"
/>
{!!errors.vae?.path && touched.vae?.path ? (
<IAIFormErrorMessage>
{errors.vae?.path}
</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.vaeLocationValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* VAE Repo ID */}
<FormControl
isInvalid={!!errors.vae?.repo_id && touched.vae?.repo_id}
>
<FormLabel htmlFor="vae.repo_id" fontSize="sm">
{t('modelManager.vaeRepoID')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae.repo_id"
name="vae.repo_id"
type="text"
width="full"
/>
{!!errors.vae?.repo_id && touched.vae?.repo_id ? (
<IAIFormErrorMessage>
{errors.vae?.repo_id}
</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.vaeRepoIDValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
<IAIButton
type="submit"
className="modal-close-btn"
isLoading={isProcessing}
>
{t('modelManager.updateModel')}
</IAIButton>
</VStack>
</IAIForm>
)}
</Formik>
</Flex>
</Flex>
) : (
<Flex
sx={{
width: '100%',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
bg: 'base.900',
}}
>
<Text fontWeight={'500'}>Pick A Model To Edit</Text>
</Flex>
);
}

Some files were not shown because too many files have changed in this diff Show More