(wip) Model Manager 3.0 UI (#3586)

...
This commit is contained in:
blessedcoolant 2023-07-04 17:34:06 +12:00 committed by GitHub
commit 92b163e95c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
63 changed files with 1927 additions and 3183 deletions

View File

@ -2,17 +2,17 @@
from typing import Literal, Optional, Union from typing import Literal, Optional, Union
from fastapi import Query from fastapi import Query, Body
from fastapi.routing import APIRouter, HTTPException from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
from invokeai.backend import BaseModelType, ModelType from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management import AddModelResult
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)] MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
models_router = APIRouter(prefix="/v1/models", tags=["models"]) models_router = APIRouter(prefix="/v1/models", tags=["models"])
class VaeRepo(BaseModel): class VaeRepo(BaseModel):
repo_id: str = Field(description="The repo ID to use for this VAE") repo_id: str = Field(description="The repo ID to use for this VAE")
path: Optional[str] = Field(description="The path to the VAE") path: Optional[str] = Field(description="The path to the VAE")
@ -51,9 +51,12 @@ class CreateModelResponse(BaseModel):
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info") info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
status: str = Field(description="The status of the API response") status: str = Field(description="The status of the API response")
class ImportModelRequest(BaseModel): class ImportModelResponse(BaseModel):
name: str = Field(description="A model path, repo_id or URL to import") name: str = Field(description="The name of the imported model")
prediction_type: Optional[Literal['epsilon','v_prediction','sample']] = Field(description='Prediction type for SDv2 checkpoint files') # base_model: str = Field(description="The base model")
# model_type: str = Field(description="The model type")
info: AddModelResult = Field(description="The model info")
status: str = Field(description="The status of the API response")
class ConversionRequest(BaseModel): class ConversionRequest(BaseModel):
name: str = Field(description="The name of the new model") name: str = Field(description="The name of the new model")
@ -86,7 +89,6 @@ async def list_models(
models = parse_obj_as(ModelsList, { "models": models_raw }) models = parse_obj_as(ModelsList, { "models": models_raw })
return models return models
@models_router.post( @models_router.post(
"/", "/",
operation_id="update_model", operation_id="update_model",
@ -109,27 +111,38 @@ async def update_model(
return model_response return model_response
@models_router.post( @models_router.post(
"/", "/import",
operation_id="import_model", operation_id="import_model",
responses={200: {"status": "success"}}, responses= {
201: {"description" : "The model imported successfully"},
404: {"description" : "The model could not be found"},
},
status_code=201,
response_model=ImportModelResponse
) )
async def import_model( async def import_model(
model_request: ImportModelRequest name: str = Query(description="A model path, repo_id or URL to import"),
) -> None: prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = Query(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
""" Add Model """ ) -> ImportModelResponse:
items_to_import = set([model_request.name]) """ Add a model using its local path, repo_id, or remote URL """
items_to_import = {name}
prediction_types = { x.value: x for x in SchedulerPredictionType } prediction_types = { x.value: x for x in SchedulerPredictionType }
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import( installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
items_to_import = items_to_import, items_to_import = items_to_import,
prediction_type_helper = lambda x: prediction_types.get(model_request.prediction_type) prediction_type_helper = lambda x: prediction_types.get(prediction_type)
) )
if len(installed_models) > 0: if info := installed_models.get(name):
logger.info(f'Successfully imported {model_request.name}') logger.info(f'Successfully imported {name}, got {info}')
return ImportModelResponse(
name = name,
info = info,
status = "success",
)
else: else:
logger.error(f'Model {model_request.name} not imported') logger.error(f'Model {name} not imported')
raise HTTPException(status_code=500, detail=f'Model {model_request.name} not imported') raise HTTPException(status_code=404, detail=f'Model {name} not found')
@models_router.delete( @models_router.delete(
"/{model_name}", "/{model_name}",

View File

@ -1,11 +1,12 @@
from typing import Literal, Optional, Union, List
from pydantic import BaseModel, Field
import copy import copy
from typing import List, Literal, Optional
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig from pydantic import BaseModel, Field
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.model_management import BaseModelType, ModelType, SubModelType from ...backend.model_management import BaseModelType, ModelType, SubModelType
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
class ModelInfo(BaseModel): class ModelInfo(BaseModel):
model_name: str = Field(description="Info to load submodel") model_name: str = Field(description="Info to load submodel")
@ -30,7 +31,6 @@ class VaeField(BaseModel):
# TODO: better naming? # TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel") vae: ModelInfo = Field(description="Info to load vae submodel")
class ModelLoaderOutput(BaseInvocationOutput): class ModelLoaderOutput(BaseInvocationOutput):
"""Model loader output""" """Model loader output"""
@ -43,25 +43,26 @@ class ModelLoaderOutput(BaseInvocationOutput):
#fmt: on #fmt: on
class PipelineModelField(BaseModel): class MainModelField(BaseModel):
"""Pipeline model field""" """Main model field"""
model_name: str = Field(description="Name of the model") model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model") base_model: BaseModelType = Field(description="Base model")
class PipelineModelLoaderInvocation(BaseInvocation): class MainModelLoaderInvocation(BaseInvocation):
"""Loads a pipeline model, outputting its submodels.""" """Loads a main model, outputting its submodels."""
type: Literal["pipeline_model_loader"] = "pipeline_model_loader" type: Literal["main_model_loader"] = "main_model_loader"
model: PipelineModelField = Field(description="The model to load") model: MainModelField = Field(description="The model to load")
# TODO: precision? # TODO: precision?
# Schema customisation # Schema customisation
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {
"title": "Model Loader",
"tags": ["model", "loader"], "tags": ["model", "loader"],
"type_hints": { "type_hints": {
"model": "model" "model": "model"
@ -175,6 +176,14 @@ class LoraLoaderInvocation(BaseInvocation):
unet: Optional[UNetField] = Field(description="UNet model for applying lora") unet: Optional[UNetField] = Field(description="UNet model for applying lora")
clip: Optional[ClipField] = Field(description="Clip model for applying lora") clip: Optional[ClipField] = Field(description="Clip model for applying lora")
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Lora Loader",
"tags": ["lora", "loader"],
},
}
def invoke(self, context: InvocationContext) -> LoraLoaderOutput: def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
# TODO: ui rewrite # TODO: ui rewrite
@ -221,3 +230,56 @@ class LoraLoaderInvocation(BaseInvocation):
return output return output
class VAEModelField(BaseModel):
"""Vae model field"""
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
class VaeLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
#fmt: off
type: Literal["vae_loader_output"] = "vae_loader_output"
vae: VaeField = Field(default=None, description="Vae model")
#fmt: on
class VaeLoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""
type: Literal["vae_loader"] = "vae_loader"
vae_model: VAEModelField = Field(description="The VAE to load")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "VAE Loader",
"tags": ["vae", "loader"],
"type_hints": {
"vae_model": "vae_model"
}
},
}
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
base_model = self.vae_model.base_model
model_name = self.vae_model.model_name
model_type = ModelType.Vae
if not context.services.model_manager.model_exists(
base_model=base_model,
model_name=model_name,
model_type=model_type,
):
raise Exception(f"Unkown vae name: {model_name}!")
return VaeLoaderOutput(
vae=VaeField(
vae = ModelInfo(
model_name = model_name,
base_model = base_model,
model_type = model_type,
)
)
)

View File

@ -135,6 +135,29 @@ class ModelManagerServiceBase(ABC):
""" """
pass pass
@abstractmethod
def heuristic_import(self,
items_to_import: Set[str],
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
)->Dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
The prediction type helper is necessary to distinguish between
models based on Stable Diffusion 2 Base (requiring
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
(requiring SchedulerPredictionType.VPrediction). It is
generally impossible to do this programmatically, so the
prediction_type_helper usually asks the user to choose.
The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.
'''
pass
@abstractmethod @abstractmethod
def commit(self, conf_file: Path = None) -> None: def commit(self, conf_file: Path = None) -> None:
""" """
@ -361,3 +384,24 @@ class ModelManagerService(ModelManagerServiceBase):
def logger(self): def logger(self):
return self.mgr.logger return self.mgr.logger
def heuristic_import(self,
items_to_import: Set[str],
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
)->Dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
The prediction type helper is necessary to distinguish between
models based on Stable Diffusion 2 Base (requiring
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
(requiring SchedulerPredictionType.VPrediction). It is
generally impossible to do this programmatically, so the
prediction_type_helper usually asks the user to choose.
The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.
'''
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)

View File

@ -18,7 +18,7 @@ from tqdm import tqdm
import invokeai.configs as configs import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
from invokeai.backend.util import download_with_resume from invokeai.backend.util import download_with_resume
from ..util.logging import InvokeAILogger from ..util.logging import InvokeAILogger
@ -166,17 +166,22 @@ class ModelInstall(object):
# add requested models # add requested models
for path in selections.install_models: for path in selections.install_models:
logger.info(f'Installing {path} [{job}/{jobs}]') logger.info(f'Installing {path} [{job}/{jobs}]')
self.heuristic_install(path) self.heuristic_import(path)
job += 1 job += 1
self.mgr.commit() self.mgr.commit()
def heuristic_install(self, def heuristic_import(self,
model_path_id_or_url: Union[str,Path], model_path_id_or_url: Union[str,Path],
models_installed: Set[Path]=None)->Set[Path]: models_installed: Set[Path]=None)->Dict[str, AddModelResult]:
'''
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
:param models_installed: Set of installed models, used for recursive invocation
Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
'''
if not models_installed: if not models_installed:
models_installed = set() models_installed = dict()
# A little hack to allow nested routines to retrieve info on the requested ID # A little hack to allow nested routines to retrieve info on the requested ID
self.current_id = model_path_id_or_url self.current_id = model_path_id_or_url
@ -185,24 +190,24 @@ class ModelInstall(object):
try: try:
# checkpoint file, or similar # checkpoint file, or similar
if path.is_file(): if path.is_file():
models_installed.add(self._install_path(path)) models_installed.update(self._install_path(path))
# folders style or similar # folders style or similar
elif path.is_dir() and any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]): elif path.is_dir() and any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
models_installed.add(self._install_path(path)) models_installed.update(self._install_path(path))
# recursive scan # recursive scan
elif path.is_dir(): elif path.is_dir():
for child in path.iterdir(): for child in path.iterdir():
self.heuristic_install(child, models_installed=models_installed) self.heuristic_import(child, models_installed=models_installed)
# huggingface repo # huggingface repo
elif len(str(path).split('/')) == 2: elif len(str(path).split('/')) == 2:
models_installed.add(self._install_repo(str(path))) models_installed.update(self._install_repo(str(path)))
# a URL # a URL
elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")): elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")):
models_installed.add(self._install_url(model_path_id_or_url)) models_installed.update(self._install_url(model_path_id_or_url))
else: else:
logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping') logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
@ -214,24 +219,25 @@ class ModelInstall(object):
# install a model from a local path. The optional info parameter is there to prevent # install a model from a local path. The optional info parameter is there to prevent
# the model from being probed twice in the event that it has already been probed. # the model from being probed twice in the event that it has already been probed.
def _install_path(self, path: Path, info: ModelProbeInfo=None)->Path: def _install_path(self, path: Path, info: ModelProbeInfo=None)->Dict[str, AddModelResult]:
try: try:
# logger.debug(f'Probing {path}') model_result = None
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper) info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
model_name = path.stem if info.format=='checkpoint' else path.name model_name = path.stem if info.format=='checkpoint' else path.name
if self.mgr.model_exists(model_name, info.base_type, info.model_type): if self.mgr.model_exists(model_name, info.base_type, info.model_type):
raise ValueError(f'A model named "{model_name}" is already installed.') raise ValueError(f'A model named "{model_name}" is already installed.')
attributes = self._make_attributes(path,info) attributes = self._make_attributes(path,info)
self.mgr.add_model(model_name = model_name, model_result = self.mgr.add_model(model_name = model_name,
base_model = info.base_type, base_model = info.base_type,
model_type = info.model_type, model_type = info.model_type,
model_attributes = attributes, model_attributes = attributes,
) )
except Exception as e: except Exception as e:
logger.warning(f'{str(e)} Skipping registration.') logger.warning(f'{str(e)} Skipping registration.')
return path return {}
return {str(path): model_result}
def _install_url(self, url: str)->Path: def _install_url(self, url: str)->dict:
# copy to a staging area, probe, import and delete # copy to a staging area, probe, import and delete
with TemporaryDirectory(dir=self.config.models_path) as staging: with TemporaryDirectory(dir=self.config.models_path) as staging:
location = download_with_resume(url,Path(staging)) location = download_with_resume(url,Path(staging))
@ -244,7 +250,7 @@ class ModelInstall(object):
# staged version will be garbage-collected at this time # staged version will be garbage-collected at this time
return self._install_path(Path(models_path), info) return self._install_path(Path(models_path), info)
def _install_repo(self, repo_id: str)->Path: def _install_repo(self, repo_id: str)->dict:
hinfo = HfApi().model_info(repo_id) hinfo = HfApi().model_info(repo_id)
# we try to figure out how to download this most economically # we try to figure out how to download this most economically

View File

@ -1,7 +1,7 @@
""" """
Initialization file for invokeai.backend.model_management Initialization file for invokeai.backend.model_management
""" """
from .model_manager import ModelManager, ModelInfo from .model_manager import ModelManager, ModelInfo, AddModelResult
from .model_cache import ModelCache from .model_cache import ModelCache
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType from .models import BaseModelType, ModelType, SubModelType, ModelVariantType

View File

@ -233,14 +233,14 @@ import hashlib
import textwrap import textwrap
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Optional, List, Tuple, Union, Set, Callable, types from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types
from shutil import rmtree from shutil import rmtree
import torch import torch
from omegaconf import OmegaConf from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig from omegaconf.dictconfig import DictConfig
from pydantic import BaseModel from pydantic import BaseModel, Field
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
@ -278,8 +278,13 @@ class InvalidModelError(Exception):
"Raised when an invalid model is requested" "Raised when an invalid model is requested"
pass pass
MAX_CACHE_SIZE = 6.0 # GB class AddModelResult(BaseModel):
name: str = Field(description="The name of the model after import")
model_type: ModelType = Field(description="The type of model")
base_model: BaseModelType = Field(description="The base model")
config: ModelConfigBase = Field(description="The configuration of the model")
MAX_CACHE_SIZE = 6.0 # GB
class ConfigMeta(BaseModel): class ConfigMeta(BaseModel):
version: str version: str
@ -571,13 +576,16 @@ class ModelManager(object):
model_type: ModelType, model_type: ModelType,
model_attributes: dict, model_attributes: dict,
clobber: bool = False, clobber: bool = False,
) -> None: ) -> AddModelResult:
""" """
Update the named model with a dictionary of attributes. Will fail with an Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite. assertion error if the name already exists. Pass clobber=True to overwrite.
On a successful update, the config will be changed in memory and the On a successful update, the config will be changed in memory and the
method will return True. Will fail with an assertion error if provided method will return True. Will fail with an assertion error if provided
attributes are incorrect or the model name is missing. attributes are incorrect or the model name is missing.
The returned dict has the same format as the dict returned by
model_info().
""" """
model_class = MODEL_CLASSES[base_model][model_type] model_class = MODEL_CLASSES[base_model][model_type]
@ -601,12 +609,18 @@ class ModelManager(object):
old_model_cache.unlink() old_model_cache.unlink()
# remove in-memory cache # remove in-memory cache
# note: it not garantie to release memory(model can has other references) # note: it not guaranteed to release memory(model can has other references)
cache_ids = self.cache_keys.pop(model_key, []) cache_ids = self.cache_keys.pop(model_key, [])
for cache_id in cache_ids: for cache_id in cache_ids:
self.cache.uncache_model(cache_id) self.cache.uncache_model(cache_id)
self.models[model_key] = model_config self.models[model_key] = model_config
return AddModelResult(
name = model_name,
model_type = model_type,
base_model = base_model,
config = model_config,
)
def search_models(self, search_folder): def search_models(self, search_folder):
self.logger.info(f"Finding Models In: {search_folder}") self.logger.info(f"Finding Models In: {search_folder}")
@ -729,7 +743,7 @@ class ModelManager(object):
if (new_models_found or imported_models) and self.config_path: if (new_models_found or imported_models) and self.config_path:
self.commit() self.commit()
def autoimport(self)->set[Path]: def autoimport(self)->Dict[str, AddModelResult]:
''' '''
Scan the autoimport directory (if defined) and import new models, delete defunct models. Scan the autoimport directory (if defined) and import new models, delete defunct models.
''' '''
@ -742,7 +756,6 @@ class ModelManager(object):
prediction_type_helper = ask_user_for_prediction_type, prediction_type_helper = ask_user_for_prediction_type,
) )
installed = set()
scanned_dirs = set() scanned_dirs = set()
config = self.app_config config = self.app_config
@ -756,13 +769,14 @@ class ModelManager(object):
continue continue
self.logger.info(f'Scanning {autodir} for models to import') self.logger.info(f'Scanning {autodir} for models to import')
installed = dict()
autodir = self.app_config.root_path / autodir autodir = self.app_config.root_path / autodir
if not autodir.exists(): if not autodir.exists():
continue continue
items_scanned = 0 items_scanned = 0
new_models_found = set() new_models_found = dict()
for root, dirs, files in os.walk(autodir): for root, dirs, files in os.walk(autodir):
items_scanned += len(dirs) + len(files) items_scanned += len(dirs) + len(files)
@ -772,7 +786,7 @@ class ModelManager(object):
scanned_dirs.add(path) scanned_dirs.add(path)
continue continue
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]): if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
new_models_found.update(installer.heuristic_install(path)) new_models_found.update(installer.heuristic_import(path))
scanned_dirs.add(path) scanned_dirs.add(path)
for f in files: for f in files:
@ -780,7 +794,7 @@ class ModelManager(object):
if path in known_paths or path.parent in scanned_dirs: if path in known_paths or path.parent in scanned_dirs:
continue continue
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}: if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
new_models_found.update(installer.heuristic_install(path)) new_models_found.update(installer.heuristic_import(path))
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models') self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
installed.update(new_models_found) installed.update(new_models_found)
@ -790,7 +804,7 @@ class ModelManager(object):
def heuristic_import(self, def heuristic_import(self,
items_to_import: Set[str], items_to_import: Set[str],
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
)->Set[str]: )->Dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of '''Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items. successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported. :param items_to_import: Set of strings corresponding to models to be imported.
@ -803,17 +817,20 @@ class ModelManager(object):
generally impossible to do this programmatically, so the generally impossible to do this programmatically, so the
prediction_type_helper usually asks the user to choose. prediction_type_helper usually asks the user to choose.
The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.
''' '''
# avoid circular import here # avoid circular import here
from invokeai.backend.install.model_install_backend import ModelInstall from invokeai.backend.install.model_install_backend import ModelInstall
successfully_installed = set() successfully_installed = dict()
installer = ModelInstall(config = self.app_config, installer = ModelInstall(config = self.app_config,
prediction_type_helper = prediction_type_helper, prediction_type_helper = prediction_type_helper,
model_manager = self) model_manager = self)
for thing in items_to_import: for thing in items_to_import:
try: try:
installed = installer.heuristic_install(thing) installed = installer.heuristic_import(thing)
successfully_installed.update(installed) successfully_installed.update(installed)
except Exception as e: except Exception as e:
self.logger.warning(f'{thing} could not be imported: {str(e)}') self.logger.warning(f'{thing} could not be imported: {str(e)}')

View File

@ -12,7 +12,7 @@
margin: 0; margin: 0;
} }
</style> </style>
<script type="module" crossorigin src="./assets/index-8a3e9251.js"></script> <script type="module" crossorigin src="./assets/index-c0367e37.js"></script>
</head> </head>
<body dir="ltr"> <body dir="ltr">

View File

@ -24,16 +24,13 @@
}, },
"common": { "common": {
"hotkeysLabel": "Hotkeys", "hotkeysLabel": "Hotkeys",
"themeLabel": "Theme", "darkMode": "Dark Mode",
"lightMode": "Light Mode",
"languagePickerLabel": "Language", "languagePickerLabel": "Language",
"reportBugLabel": "Report Bug", "reportBugLabel": "Report Bug",
"githubLabel": "Github", "githubLabel": "Github",
"discordLabel": "Discord", "discordLabel": "Discord",
"settingsLabel": "Settings", "settingsLabel": "Settings",
"darkTheme": "Dark",
"lightTheme": "Light",
"greenTheme": "Green",
"oceanTheme": "Ocean",
"langArabic": "العربية", "langArabic": "العربية",
"langEnglish": "English", "langEnglish": "English",
"langDutch": "Nederlands", "langDutch": "Nederlands",
@ -55,6 +52,7 @@
"unifiedCanvas": "Unified Canvas", "unifiedCanvas": "Unified Canvas",
"linear": "Linear", "linear": "Linear",
"nodes": "Node Editor", "nodes": "Node Editor",
"modelmanager": "Model Manager",
"postprocessing": "Post Processing", "postprocessing": "Post Processing",
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.", "nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
"postProcessing": "Post Processing", "postProcessing": "Post Processing",
@ -336,6 +334,7 @@
"modelManager": { "modelManager": {
"modelManager": "Model Manager", "modelManager": "Model Manager",
"model": "Model", "model": "Model",
"vae": "VAE",
"allModels": "All Models", "allModels": "All Models",
"checkpointModels": "Checkpoints", "checkpointModels": "Checkpoints",
"diffusersModels": "Diffusers", "diffusersModels": "Diffusers",
@ -351,6 +350,7 @@
"scanForModels": "Scan For Models", "scanForModels": "Scan For Models",
"addManually": "Add Manually", "addManually": "Add Manually",
"manual": "Manual", "manual": "Manual",
"baseModel": "Base Model",
"name": "Name", "name": "Name",
"nameValidationMsg": "Enter a name for your model", "nameValidationMsg": "Enter a name for your model",
"description": "Description", "description": "Description",
@ -363,6 +363,7 @@
"repoIDValidationMsg": "Online repository of your model", "repoIDValidationMsg": "Online repository of your model",
"vaeLocation": "VAE Location", "vaeLocation": "VAE Location",
"vaeLocationValidationMsg": "Path to where your VAE is located.", "vaeLocationValidationMsg": "Path to where your VAE is located.",
"variant": "Variant",
"vaeRepoID": "VAE Repo ID", "vaeRepoID": "VAE Repo ID",
"vaeRepoIDValidationMsg": "Online repository of your VAE", "vaeRepoIDValidationMsg": "Online repository of your VAE",
"width": "Width", "width": "Width",
@ -524,7 +525,8 @@
"initialImage": "Initial Image", "initialImage": "Initial Image",
"showOptionsPanel": "Show Options Panel", "showOptionsPanel": "Show Options Panel",
"hidePreview": "Hide Preview", "hidePreview": "Hide Preview",
"showPreview": "Show Preview" "showPreview": "Show Preview",
"controlNetControlMode": "Control Mode"
}, },
"settings": { "settings": {
"models": "Models", "models": "Models",
@ -547,7 +549,8 @@
"general": "General", "general": "General",
"generation": "Generation", "generation": "Generation",
"ui": "User Interface", "ui": "User Interface",
"availableSchedulers": "Available Schedulers" "favoriteSchedulers": "Favorite Schedulers",
"favoriteSchedulersPlaceholder": "No schedulers favorited"
}, },
"toast": { "toast": {
"serverError": "Server Error", "serverError": "Server Error",

View File

@ -67,6 +67,7 @@
"@fontsource-variable/inter": "^5.0.3", "@fontsource-variable/inter": "^5.0.3",
"@fontsource/inter": "^5.0.3", "@fontsource/inter": "^5.0.3",
"@mantine/core": "^6.0.14", "@mantine/core": "^6.0.14",
"@mantine/form": "^6.0.15",
"@mantine/hooks": "^6.0.14", "@mantine/hooks": "^6.0.14",
"@reduxjs/toolkit": "^1.9.5", "@reduxjs/toolkit": "^1.9.5",
"@roarr/browser-log-writer": "^1.1.5", "@roarr/browser-log-writer": "^1.1.5",

View File

@ -53,6 +53,7 @@
"linear": "Linear", "linear": "Linear",
"nodes": "Node Editor", "nodes": "Node Editor",
"batch": "Batch Manager", "batch": "Batch Manager",
"modelmanager": "Model Manager",
"postprocessing": "Post Processing", "postprocessing": "Post Processing",
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.", "nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
"postProcessing": "Post Processing", "postProcessing": "Post Processing",
@ -334,6 +335,7 @@
"modelManager": { "modelManager": {
"modelManager": "Model Manager", "modelManager": "Model Manager",
"model": "Model", "model": "Model",
"vae": "VAE",
"allModels": "All Models", "allModels": "All Models",
"checkpointModels": "Checkpoints", "checkpointModels": "Checkpoints",
"diffusersModels": "Diffusers", "diffusersModels": "Diffusers",
@ -349,6 +351,7 @@
"scanForModels": "Scan For Models", "scanForModels": "Scan For Models",
"addManually": "Add Manually", "addManually": "Add Manually",
"manual": "Manual", "manual": "Manual",
"baseModel": "Base Model",
"name": "Name", "name": "Name",
"nameValidationMsg": "Enter a name for your model", "nameValidationMsg": "Enter a name for your model",
"description": "Description", "description": "Description",
@ -361,6 +364,7 @@
"repoIDValidationMsg": "Online repository of your model", "repoIDValidationMsg": "Online repository of your model",
"vaeLocation": "VAE Location", "vaeLocation": "VAE Location",
"vaeLocationValidationMsg": "Path to where your VAE is located.", "vaeLocationValidationMsg": "Path to where your VAE is located.",
"variant": "Variant",
"vaeRepoID": "VAE Repo ID", "vaeRepoID": "VAE Repo ID",
"vaeRepoIDValidationMsg": "Online repository of your VAE", "vaeRepoIDValidationMsg": "Online repository of your VAE",
"width": "Width", "width": "Width",

View File

@ -4,6 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { PartialAppConfig } from 'app/types/invokeai'; import { PartialAppConfig } from 'app/types/invokeai';
import ImageUploader from 'common/components/ImageUploader'; import ImageUploader from 'common/components/ImageUploader';
import GalleryDrawer from 'features/gallery/components/GalleryPanel'; import GalleryDrawer from 'features/gallery/components/GalleryPanel';
import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal';
import Lightbox from 'features/lightbox/components/Lightbox'; import Lightbox from 'features/lightbox/components/Lightbox';
import SiteHeader from 'features/system/components/SiteHeader'; import SiteHeader from 'features/system/components/SiteHeader';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
@ -15,11 +16,10 @@ import InvokeTabs from 'features/ui/components/InvokeTabs';
import ParametersDrawer from 'features/ui/components/ParametersDrawer'; import ParametersDrawer from 'features/ui/components/ParametersDrawer';
import i18n from 'i18n'; import i18n from 'i18n';
import { ReactNode, memo, useEffect } from 'react'; import { ReactNode, memo, useEffect } from 'react';
import DeleteBoardImagesModal from '../../features/gallery/components/Boards/DeleteBoardImagesModal';
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
import GlobalHotkeys from './GlobalHotkeys'; import GlobalHotkeys from './GlobalHotkeys';
import Toaster from './Toaster'; import Toaster from './Toaster';
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
import DeleteBoardImagesModal from '../../features/gallery/components/Boards/DeleteBoardImagesModal';
import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal';
const DEFAULT_CONFIG = {}; const DEFAULT_CONFIG = {};

View File

@ -3,20 +3,21 @@ import { memo } from 'react';
import { InputFieldTemplate, InputFieldValue } from '../types/types'; import { InputFieldTemplate, InputFieldValue } from '../types/types';
import ArrayInputFieldComponent from './fields/ArrayInputFieldComponent'; import ArrayInputFieldComponent from './fields/ArrayInputFieldComponent';
import BooleanInputFieldComponent from './fields/BooleanInputFieldComponent'; import BooleanInputFieldComponent from './fields/BooleanInputFieldComponent';
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
import UNetInputFieldComponent from './fields/UNetInputFieldComponent';
import ClipInputFieldComponent from './fields/ClipInputFieldComponent'; import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
import VaeInputFieldComponent from './fields/VaeInputFieldComponent'; import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
import ControlInputFieldComponent from './fields/ControlInputFieldComponent'; import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
import ModelInputFieldComponent from './fields/ModelInputFieldComponent'; import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
import NumberInputFieldComponent from './fields/NumberInputFieldComponent'; import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
import StringInputFieldComponent from './fields/StringInputFieldComponent'; import StringInputFieldComponent from './fields/StringInputFieldComponent';
import ColorInputFieldComponent from './fields/ColorInputFieldComponent'; import UNetInputFieldComponent from './fields/UNetInputFieldComponent';
import ItemInputFieldComponent from './fields/ItemInputFieldComponent'; import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent'; import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent';
type InputFieldComponentProps = { type InputFieldComponentProps = {
nodeId: string; nodeId: string;
@ -152,6 +153,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
); );
} }
if (type === 'vae_model' && template.type === 'vae_model') {
return (
<VaeModelInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'array' && template.type === 'array') { if (type === 'array' && template.type === 'array') {
return ( return (
<ArrayInputFieldComponent <ArrayInputFieldComponent

View File

@ -6,13 +6,13 @@ import {
ModelInputFieldValue, ModelInputFieldValue,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { memo, useCallback, useEffect, useMemo } from 'react';
import { FieldComponentProps } from './types';
import { forEach, isString } from 'lodash-es';
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
import { forEach, isString } from 'lodash-es';
import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useListModelsQuery } from 'services/api/endpoints/models'; import { useListModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
const ModelInputFieldComponent = ( const ModelInputFieldComponent = (
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate> props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
@ -22,18 +22,18 @@ const ModelInputFieldComponent = (
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const { data: pipelineModels } = useListModelsQuery({ const { data: mainModels } = useListModelsQuery({
model_type: 'main', model_type: 'main',
}); });
const data = useMemo(() => { const data = useMemo(() => {
if (!pipelineModels) { if (!mainModels) {
return []; return [];
} }
const data: SelectItem[] = []; const data: SelectItem[] = [];
forEach(pipelineModels.entities, (model, id) => { forEach(mainModels.entities, (model, id) => {
if (!model) { if (!model) {
return; return;
} }
@ -46,11 +46,11 @@ const ModelInputFieldComponent = (
}); });
return data; return data;
}, [pipelineModels]); }, [mainModels]);
const selectedModel = useMemo( const selectedModel = useMemo(
() => pipelineModels?.entities[field.value ?? pipelineModels.ids[0]], () => mainModels?.entities[field.value ?? mainModels.ids[0]],
[pipelineModels?.entities, pipelineModels?.ids, field.value] [mainModels?.entities, mainModels?.ids, field.value]
); );
const handleValueChanged = useCallback( const handleValueChanged = useCallback(
@ -71,18 +71,18 @@ const ModelInputFieldComponent = (
); );
useEffect(() => { useEffect(() => {
if (field.value && pipelineModels?.ids.includes(field.value)) { if (field.value && mainModels?.ids.includes(field.value)) {
return; return;
} }
const firstModel = pipelineModels?.ids[0]; const firstModel = mainModels?.ids[0];
if (!isString(firstModel)) { if (!isString(firstModel)) {
return; return;
} }
handleValueChanged(firstModel); handleValueChanged(firstModel);
}, [field.value, handleValueChanged, pipelineModels?.ids]); }, [field.value, handleValueChanged, mainModels?.ids]);
return ( return (
<IAIMantineSelect <IAIMantineSelect

View File

@ -0,0 +1,97 @@
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
VaeModelInputFieldTemplate,
VaeModelInputFieldValue,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
import { forEach } from 'lodash-es';
import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useListModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
const VaeModelInputFieldComponent = (
props: FieldComponentProps<
VaeModelInputFieldValue,
VaeModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { data: vaeModels } = useListModelsQuery({
model_type: 'vae',
});
const selectedModel = useMemo(
() => vaeModels?.entities[field.value ?? vaeModels.ids[0]],
[vaeModels?.entities, vaeModels?.ids, field.value]
);
const data = useMemo(() => {
if (!vaeModels) {
return [];
}
const data: SelectItem[] = [];
forEach(vaeModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.name,
group: BASE_MODEL_NAME_MAP[model.base_model],
});
});
return data;
}, [vaeModels]);
const handleValueChanged = useCallback(
(v: string | null) => {
if (!v) {
return;
}
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: v,
})
);
},
[dispatch, field.name, nodeId]
);
useEffect(() => {
if (field.value && vaeModels?.ids.includes(field.value)) {
return;
}
handleValueChanged('auto');
}, [field.value, handleValueChanged, vaeModels?.ids]);
return (
<IAIMantineSelect
tooltip={selectedModel?.description}
label={
selectedModel?.base_model &&
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
}
value={field.value}
placeholder="Pick one"
data={data}
onChange={handleValueChanged}
/>
);
};
export default memo(VaeModelInputFieldComponent);

View File

@ -17,6 +17,7 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
ClipField: 'clip', ClipField: 'clip',
VaeField: 'vae', VaeField: 'vae',
model: 'model', model: 'model',
vae_model: 'vae_model',
array: 'array', array: 'array',
item: 'item', item: 'item',
ColorField: 'color', ColorField: 'color',
@ -116,6 +117,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
title: 'Model', title: 'Model',
description: 'Models are models.', description: 'Models are models.',
}, },
vae_model: {
color: 'teal',
colorCssVar: getColorTokenCssVariable('teal'),
title: 'Model',
description: 'Models are models.',
},
array: { array: {
color: 'gray', color: 'gray',
colorCssVar: getColorTokenCssVariable('gray'), colorCssVar: getColorTokenCssVariable('gray'),

View File

@ -64,6 +64,7 @@ export type FieldType =
| 'vae' | 'vae'
| 'control' | 'control'
| 'model' | 'model'
| 'vae_model'
| 'array' | 'array'
| 'item' | 'item'
| 'color' | 'color'
@ -91,6 +92,7 @@ export type InputFieldValue =
| ControlInputFieldValue | ControlInputFieldValue
| EnumInputFieldValue | EnumInputFieldValue
| ModelInputFieldValue | ModelInputFieldValue
| VaeModelInputFieldValue
| ArrayInputFieldValue | ArrayInputFieldValue
| ItemInputFieldValue | ItemInputFieldValue
| ColorInputFieldValue | ColorInputFieldValue
@ -116,6 +118,7 @@ export type InputFieldTemplate =
| ControlInputFieldTemplate | ControlInputFieldTemplate
| EnumInputFieldTemplate | EnumInputFieldTemplate
| ModelInputFieldTemplate | ModelInputFieldTemplate
| VaeModelInputFieldTemplate
| ArrayInputFieldTemplate | ArrayInputFieldTemplate
| ItemInputFieldTemplate | ItemInputFieldTemplate
| ColorInputFieldTemplate | ColorInputFieldTemplate
@ -228,6 +231,11 @@ export type ModelInputFieldValue = FieldValueBase & {
value?: string; value?: string;
}; };
export type VaeModelInputFieldValue = FieldValueBase & {
type: 'vae_model';
value?: string;
};
export type ArrayInputFieldValue = FieldValueBase & { export type ArrayInputFieldValue = FieldValueBase & {
type: 'array'; type: 'array';
value?: (string | number)[]; value?: (string | number)[];
@ -305,6 +313,21 @@ export type ConditioningInputFieldTemplate = InputFieldTemplateBase & {
type: 'conditioning'; type: 'conditioning';
}; };
export type UNetInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'unet';
};
export type ClipInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'clip';
};
export type VaeInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'vae';
};
export type ControlInputFieldTemplate = InputFieldTemplateBase & { export type ControlInputFieldTemplate = InputFieldTemplateBase & {
default: undefined; default: undefined;
type: 'control'; type: 'control';
@ -322,6 +345,11 @@ export type ModelInputFieldTemplate = InputFieldTemplateBase & {
type: 'model'; type: 'model';
}; };
export type VaeModelInputFieldTemplate = InputFieldTemplateBase & {
default: string;
type: 'vae_model';
};
export type ArrayInputFieldTemplate = InputFieldTemplateBase & { export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
default: []; default: [];
type: 'array'; type: 'array';

View File

@ -3,27 +3,28 @@ import { OpenAPIV3 } from 'openapi-types';
import { FIELD_TYPE_MAP } from '../types/constants'; import { FIELD_TYPE_MAP } from '../types/constants';
import { isSchemaObject } from '../types/typeGuards'; import { isSchemaObject } from '../types/typeGuards';
import { import {
BooleanInputFieldTemplate,
EnumInputFieldTemplate,
FloatInputFieldTemplate,
ImageInputFieldTemplate,
IntegerInputFieldTemplate,
LatentsInputFieldTemplate,
ConditioningInputFieldTemplate,
UNetInputFieldTemplate,
ClipInputFieldTemplate,
VaeInputFieldTemplate,
ControlInputFieldTemplate,
StringInputFieldTemplate,
ModelInputFieldTemplate,
ArrayInputFieldTemplate, ArrayInputFieldTemplate,
ItemInputFieldTemplate, BooleanInputFieldTemplate,
ClipInputFieldTemplate,
ColorInputFieldTemplate, ColorInputFieldTemplate,
InputFieldTemplateBase, ConditioningInputFieldTemplate,
OutputFieldTemplate, ControlInputFieldTemplate,
TypeHints, EnumInputFieldTemplate,
FieldType, FieldType,
FloatInputFieldTemplate,
ImageCollectionInputFieldTemplate, ImageCollectionInputFieldTemplate,
ImageInputFieldTemplate,
InputFieldTemplateBase,
IntegerInputFieldTemplate,
ItemInputFieldTemplate,
LatentsInputFieldTemplate,
ModelInputFieldTemplate,
OutputFieldTemplate,
StringInputFieldTemplate,
TypeHints,
UNetInputFieldTemplate,
VaeInputFieldTemplate,
VaeModelInputFieldTemplate,
} from '../types/types'; } from '../types/types';
export type BaseFieldProperties = 'name' | 'title' | 'description'; export type BaseFieldProperties = 'name' | 'title' | 'description';
@ -175,6 +176,21 @@ const buildModelInputFieldTemplate = ({
return template; return template;
}; };
const buildVaeModelInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): VaeModelInputFieldTemplate => {
const template: VaeModelInputFieldTemplate = {
...baseField,
type: 'vae_model',
inputRequirement: 'always',
inputKind: 'direct',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildImageInputFieldTemplate = ({ const buildImageInputFieldTemplate = ({
schemaObject, schemaObject,
baseField, baseField,
@ -441,6 +457,9 @@ export const buildInputFieldTemplate = (
if (['model'].includes(fieldType)) { if (['model'].includes(fieldType)) {
return buildModelInputFieldTemplate({ schemaObject, baseField }); return buildModelInputFieldTemplate({ schemaObject, baseField });
} }
if (['vae_model'].includes(fieldType)) {
return buildVaeModelInputFieldTemplate({ schemaObject, baseField });
}
if (['enum'].includes(fieldType)) { if (['enum'].includes(fieldType)) {
return buildEnumInputFieldTemplate({ schemaObject, baseField }); return buildEnumInputFieldTemplate({ schemaObject, baseField });
} }

View File

@ -75,6 +75,10 @@ export const buildInputFieldValue = (
if (template.type === 'model') { if (template.type === 'model') {
fieldValue.value = undefined; fieldValue.value = undefined;
} }
if (template.type === 'vae_model') {
fieldValue.value = undefined;
}
} }
return fieldValue; return fieldValue;

View File

@ -0,0 +1,68 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
import {
IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS,
INPAINT,
INPAINT_GRAPH,
LATENTS_TO_IMAGE,
MAIN_MODEL_LOADER,
TEXT_TO_IMAGE_GRAPH,
VAE_LOADER,
} from './constants';
export const addVAEToGraph = (
graph: NonNullableGraph,
state: RootState
): void => {
const { vae: vaeId } = state.generation;
const vae_model = modelIdToVAEModelField(vaeId);
if (vaeId !== 'auto') {
graph.nodes[VAE_LOADER] = {
type: 'vae_loader',
id: VAE_LOADER,
vae_model,
};
}
if (graph.id === TEXT_TO_IMAGE_GRAPH || graph.id === IMAGE_TO_IMAGE_GRAPH) {
graph.edges.push({
source: {
node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'vae',
},
});
}
if (graph.id === IMAGE_TO_IMAGE_GRAPH) {
graph.edges.push({
source: {
node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae',
},
destination: {
node_id: IMAGE_TO_LATENTS,
field: 'vae',
},
});
}
if (graph.id === INPAINT_GRAPH) {
graph.edges.push({
source: {
node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae',
},
destination: {
node_id: INPAINT,
field: 'vae',
},
});
}
};

View File

@ -1,31 +1,26 @@
import { log } from 'app/logging/useLogger';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { import {
ImageDTO, ImageDTO,
ImageResizeInvocation, ImageResizeInvocation,
ImageToLatentsInvocation, ImageToLatentsInvocation,
RandomIntInvocation,
RangeOfSizeInvocation,
} from 'services/api/types'; } from 'services/api/types';
import { NonNullableGraph } from 'features/nodes/types/types'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { log } from 'app/logging/useLogger'; import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { import {
ITERATE, IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
PIPELINE_MODEL_LOADER, LATENTS_TO_LATENTS,
MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS,
LATENTS_TO_LATENTS,
RESIZE, RESIZE,
} from './constants'; } from './constants';
import { set } from 'lodash-es';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
const moduleLog = log.child({ namespace: 'nodes' }); const moduleLog = log.child({ namespace: 'nodes' });
@ -52,7 +47,7 @@ export const buildCanvasImageToImageGraph = (
// The bounding box determines width and height, not the width and height params // The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions; const { width, height } = state.canvas.boundingBoxDimensions;
const model = modelIdToPipelineModelField(modelId); const model = modelIdToMainModelField(modelId);
/** /**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
@ -81,9 +76,9 @@ export const buildCanvasImageToImageGraph = (
type: 'noise', type: 'noise',
id: NOISE, id: NOISE,
}, },
[PIPELINE_MODEL_LOADER]: { [MAIN_MODEL_LOADER]: {
type: 'pipeline_model_loader', type: 'main_model_loader',
id: PIPELINE_MODEL_LOADER, id: MAIN_MODEL_LOADER,
model, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
@ -110,7 +105,7 @@ export const buildCanvasImageToImageGraph = (
edges: [ edges: [
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -120,7 +115,7 @@ export const buildCanvasImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -128,16 +123,6 @@ export const buildCanvasImageToImageGraph = (
field: 'clip', field: 'clip',
}, },
}, },
{
source: {
node_id: PIPELINE_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'vae',
},
},
{ {
source: { source: {
node_id: LATENTS_TO_LATENTS, node_id: LATENTS_TO_LATENTS,
@ -170,17 +155,7 @@ export const buildCanvasImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: IMAGE_TO_LATENTS,
field: 'vae',
},
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -277,6 +252,9 @@ export const buildCanvasImageToImageGraph = (
}); });
} }
// Add VAE
addVAEToGraph(graph, state);
// add dynamic prompts, mutating `graph` // add dynamic prompts, mutating `graph`
addDynamicPromptsToGraph(graph, state); addDynamicPromptsToGraph(graph, state);

View File

@ -1,23 +1,24 @@
import { log } from 'app/logging/useLogger';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { import {
ImageDTO, ImageDTO,
InpaintInvocation, InpaintInvocation,
RandomIntInvocation, RandomIntInvocation,
RangeOfSizeInvocation, RangeOfSizeInvocation,
} from 'services/api/types'; } from 'services/api/types';
import { NonNullableGraph } from 'features/nodes/types/types'; import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { log } from 'app/logging/useLogger'; import { addVAEToGraph } from './addVAEToGraph';
import { import {
INPAINT,
INPAINT_GRAPH,
ITERATE, ITERATE,
PIPELINE_MODEL_LOADER, MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
RANDOM_INT, RANDOM_INT,
RANGE_OF_SIZE, RANGE_OF_SIZE,
INPAINT_GRAPH,
INPAINT,
} from './constants'; } from './constants';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
const moduleLog = log.child({ namespace: 'nodes' }); const moduleLog = log.child({ namespace: 'nodes' });
@ -55,7 +56,7 @@ export const buildCanvasInpaintGraph = (
// We may need to set the inpaint width and height to scale the image // We may need to set the inpaint width and height to scale the image
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas; const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const model = modelIdToPipelineModelField(modelId); const model = modelIdToMainModelField(modelId);
const graph: NonNullableGraph = { const graph: NonNullableGraph = {
id: INPAINT_GRAPH, id: INPAINT_GRAPH,
@ -101,9 +102,9 @@ export const buildCanvasInpaintGraph = (
id: NEGATIVE_CONDITIONING, id: NEGATIVE_CONDITIONING,
prompt: negativePrompt, prompt: negativePrompt,
}, },
[PIPELINE_MODEL_LOADER]: { [MAIN_MODEL_LOADER]: {
type: 'pipeline_model_loader', type: 'main_model_loader',
id: PIPELINE_MODEL_LOADER, id: MAIN_MODEL_LOADER,
model, model,
}, },
[RANGE_OF_SIZE]: { [RANGE_OF_SIZE]: {
@ -142,7 +143,7 @@ export const buildCanvasInpaintGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -152,7 +153,7 @@ export const buildCanvasInpaintGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -162,7 +163,7 @@ export const buildCanvasInpaintGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -170,16 +171,6 @@ export const buildCanvasInpaintGraph = (
field: 'unet', field: 'unet',
}, },
}, },
{
source: {
node_id: PIPELINE_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: INPAINT,
field: 'vae',
},
},
{ {
source: { source: {
node_id: RANGE_OF_SIZE, node_id: RANGE_OF_SIZE,
@ -203,6 +194,9 @@ export const buildCanvasInpaintGraph = (
], ],
}; };
// Add VAE
addVAEToGraph(graph, state);
// handle seed // handle seed
if (shouldRandomizeSeed) { if (shouldRandomizeSeed) {
// Random int node to generate the starting seed // Random int node to generate the starting seed

View File

@ -1,21 +1,18 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api/types'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { import {
ITERATE,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
PIPELINE_MODEL_LOADER, MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
TEXT_TO_IMAGE_GRAPH, TEXT_TO_IMAGE_GRAPH,
TEXT_TO_LATENTS, TEXT_TO_LATENTS,
} from './constants'; } from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
/** /**
* Builds the Canvas tab's Text to Image graph. * Builds the Canvas tab's Text to Image graph.
@ -38,7 +35,7 @@ export const buildCanvasTextToImageGraph = (
// The bounding box determines width and height, not the width and height params // The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions; const { width, height } = state.canvas.boundingBoxDimensions;
const model = modelIdToPipelineModelField(modelId); const model = modelIdToMainModelField(modelId);
/** /**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
@ -76,9 +73,9 @@ export const buildCanvasTextToImageGraph = (
scheduler, scheduler,
steps, steps,
}, },
[PIPELINE_MODEL_LOADER]: { [MAIN_MODEL_LOADER]: {
type: 'pipeline_model_loader', type: 'main_model_loader',
id: PIPELINE_MODEL_LOADER, id: MAIN_MODEL_LOADER,
model, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
@ -109,7 +106,7 @@ export const buildCanvasTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -119,7 +116,7 @@ export const buildCanvasTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -129,7 +126,7 @@ export const buildCanvasTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -147,16 +144,6 @@ export const buildCanvasTextToImageGraph = (
field: 'latents', field: 'latents',
}, },
}, },
{
source: {
node_id: PIPELINE_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'vae',
},
},
{ {
source: { source: {
node_id: NOISE, node_id: NOISE,
@ -170,6 +157,9 @@ export const buildCanvasTextToImageGraph = (
], ],
}; };
// Add VAE
addVAEToGraph(graph, state);
// add dynamic prompts, mutating `graph` // add dynamic prompts, mutating `graph`
addDynamicPromptsToGraph(graph, state); addDynamicPromptsToGraph(graph, state);

View File

@ -1,28 +1,29 @@
import { log } from 'app/logging/useLogger';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { import {
ImageCollectionInvocation, ImageCollectionInvocation,
ImageResizeInvocation, ImageResizeInvocation,
ImageToLatentsInvocation, ImageToLatentsInvocation,
IterateInvocation, IterateInvocation,
} from 'services/api/types'; } from 'services/api/types';
import { NonNullableGraph } from 'features/nodes/types/types'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { log } from 'app/logging/useLogger'; import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { import {
IMAGE_COLLECTION,
IMAGE_COLLECTION_ITERATE,
IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
PIPELINE_MODEL_LOADER, LATENTS_TO_LATENTS,
MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS,
LATENTS_TO_LATENTS,
RESIZE, RESIZE,
IMAGE_COLLECTION,
IMAGE_COLLECTION_ITERATE,
} from './constants'; } from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
const moduleLog = log.child({ namespace: 'nodes' }); const moduleLog = log.child({ namespace: 'nodes' });
@ -69,7 +70,7 @@ export const buildLinearImageToImageGraph = (
throw new Error('No initial image found in state'); throw new Error('No initial image found in state');
} }
const model = modelIdToPipelineModelField(modelId); const model = modelIdToMainModelField(modelId);
// copy-pasted graph from node editor, filled in with state values & friendly node ids // copy-pasted graph from node editor, filled in with state values & friendly node ids
const graph: NonNullableGraph = { const graph: NonNullableGraph = {
@ -89,9 +90,9 @@ export const buildLinearImageToImageGraph = (
type: 'noise', type: 'noise',
id: NOISE, id: NOISE,
}, },
[PIPELINE_MODEL_LOADER]: { [MAIN_MODEL_LOADER]: {
type: 'pipeline_model_loader', type: 'main_model_loader',
id: PIPELINE_MODEL_LOADER, id: MAIN_MODEL_LOADER,
model, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
@ -118,7 +119,7 @@ export const buildLinearImageToImageGraph = (
edges: [ edges: [
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -128,7 +129,7 @@ export const buildLinearImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -136,16 +137,6 @@ export const buildLinearImageToImageGraph = (
field: 'clip', field: 'clip',
}, },
}, },
{
source: {
node_id: PIPELINE_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'vae',
},
},
{ {
source: { source: {
node_id: LATENTS_TO_LATENTS, node_id: LATENTS_TO_LATENTS,
@ -176,19 +167,10 @@ export const buildLinearImageToImageGraph = (
field: 'noise', field: 'noise',
}, },
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: IMAGE_TO_LATENTS,
field: 'vae',
},
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -322,6 +304,8 @@ export const buildLinearImageToImageGraph = (
}, },
}); });
} }
// Add VAE
addVAEToGraph(graph, state);
// add dynamic prompts, mutating `graph` // add dynamic prompts, mutating `graph`
addDynamicPromptsToGraph(graph, state); addDynamicPromptsToGraph(graph, state);

View File

@ -1,17 +1,18 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { import {
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
PIPELINE_MODEL_LOADER, MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
TEXT_TO_IMAGE_GRAPH, TEXT_TO_IMAGE_GRAPH,
TEXT_TO_LATENTS, TEXT_TO_LATENTS,
} from './constants'; } from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
export const buildLinearTextToImageGraph = ( export const buildLinearTextToImageGraph = (
state: RootState state: RootState
@ -27,7 +28,7 @@ export const buildLinearTextToImageGraph = (
height, height,
} = state.generation; } = state.generation;
const model = modelIdToPipelineModelField(modelId); const model = modelIdToMainModelField(modelId);
/** /**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
@ -65,9 +66,9 @@ export const buildLinearTextToImageGraph = (
scheduler, scheduler,
steps, steps,
}, },
[PIPELINE_MODEL_LOADER]: { [MAIN_MODEL_LOADER]: {
type: 'pipeline_model_loader', type: 'main_model_loader',
id: PIPELINE_MODEL_LOADER, id: MAIN_MODEL_LOADER,
model, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
@ -98,7 +99,7 @@ export const buildLinearTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -108,7 +109,7 @@ export const buildLinearTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -118,7 +119,7 @@ export const buildLinearTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -136,16 +137,6 @@ export const buildLinearTextToImageGraph = (
field: 'latents', field: 'latents',
}, },
}, },
{
source: {
node_id: PIPELINE_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'vae',
},
},
{ {
source: { source: {
node_id: NOISE, node_id: NOISE,
@ -159,6 +150,9 @@ export const buildLinearTextToImageGraph = (
], ],
}; };
// Add Custom VAE Support
addVAEToGraph(graph, state);
// add dynamic prompts, mutating `graph` // add dynamic prompts, mutating `graph`
addDynamicPromptsToGraph(graph, state); addDynamicPromptsToGraph(graph, state);

View File

@ -1,10 +1,11 @@
import { Graph } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
import { cloneDeep, omit, reduce } from 'lodash-es';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { InputFieldValue } from 'features/nodes/types/types'; import { InputFieldValue } from 'features/nodes/types/types';
import { cloneDeep, omit, reduce } from 'lodash-es';
import { Graph } from 'services/api/types';
import { AnyInvocation } from 'services/events/types'; import { AnyInvocation } from 'services/events/types';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; import { v4 as uuidv4 } from 'uuid';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
/** /**
* We need to do special handling for some fields * We need to do special handling for some fields
@ -27,7 +28,13 @@ export const parseFieldValue = (field: InputFieldValue) => {
if (field.type === 'model') { if (field.type === 'model') {
if (field.value) { if (field.value) {
return modelIdToPipelineModelField(field.value); return modelIdToMainModelField(field.value);
}
}
if (field.type === 'vae_model') {
if (field.value) {
return modelIdToVAEModelField(field.value);
} }
} }

View File

@ -7,7 +7,8 @@ export const NOISE = 'noise';
export const RANDOM_INT = 'rand_int'; export const RANDOM_INT = 'rand_int';
export const RANGE_OF_SIZE = 'range_of_size'; export const RANGE_OF_SIZE = 'range_of_size';
export const ITERATE = 'iterate'; export const ITERATE = 'iterate';
export const PIPELINE_MODEL_LOADER = 'pipeline_model_loader'; export const MAIN_MODEL_LOADER = 'main_model_loader';
export const VAE_LOADER = 'vae_loader';
export const IMAGE_TO_LATENTS = 'image_to_latents'; export const IMAGE_TO_LATENTS = 'image_to_latents';
export const LATENTS_TO_LATENTS = 'latents_to_latents'; export const LATENTS_TO_LATENTS = 'latents_to_latents';
export const RESIZE = 'resize_image'; export const RESIZE = 'resize_image';

View File

@ -0,0 +1,16 @@
import { BaseModelType, MainModelField } from 'services/api/types';
/**
* Crudely converts a model id to a main model field
* TODO: Make better
*/
export const modelIdToMainModelField = (modelId: string): MainModelField => {
const [base_model, model_type, model_name] = modelId.split('/');
const field: MainModelField = {
base_model: base_model as BaseModelType,
model_name,
};
return field;
};

View File

@ -1,18 +0,0 @@
import { BaseModelType, PipelineModelField } from 'services/api/types';
/**
* Crudely converts a model id to a pipeline model field
* TODO: Make better
*/
export const modelIdToPipelineModelField = (
modelId: string
): PipelineModelField => {
const [base_model, model_type, model_name] = modelId.split('/');
const field: PipelineModelField = {
base_model: base_model as BaseModelType,
model_name,
};
return field;
};

View File

@ -0,0 +1,16 @@
import { BaseModelType, VAEModelField } from 'services/api/types';
/**
* Crudely converts a model id to a main model field
* TODO: Make better
*/
export const modelIdToVAEModelField = (modelId: string): VAEModelField => {
const [base_model, model_type, model_name] = modelId.split('/');
const field: VAEModelField = {
base_model: base_model as BaseModelType,
model_name,
};
return field;
};

View File

@ -1,19 +1,19 @@
import { Box, Flex } from '@chakra-ui/react'; import { Box, Flex } from '@chakra-ui/react';
import ModelSelect from 'features/system/components/ModelSelect'; import ModelSelect from 'features/system/components/ModelSelect';
import VAESelect from 'features/system/components/VAESelect';
import { memo } from 'react'; import { memo } from 'react';
import ParamScheduler from './ParamScheduler';
const ParamSchedulerAndModel = () => { const ParamModelandVAE = () => {
return ( return (
<Flex gap={3} w="full"> <Flex gap={3} w="full">
<Box w="25rem">
<ParamScheduler />
</Box>
<Box w="full"> <Box w="full">
<ModelSelect /> <ModelSelect />
</Box> </Box>
<Box w="full">
<VAESelect />
</Box>
</Flex> </Flex>
); );
}; };
export default memo(ParamSchedulerAndModel); export default memo(ParamModelandVAE);

View File

@ -14,6 +14,7 @@ import {
SeedParam, SeedParam,
StepsParam, StepsParam,
StrengthParam, StrengthParam,
VAEParam,
WidthParam, WidthParam,
} from './parameterZodSchemas'; } from './parameterZodSchemas';
@ -47,6 +48,7 @@ export interface GenerationState {
horizontalSymmetrySteps: number; horizontalSymmetrySteps: number;
verticalSymmetrySteps: number; verticalSymmetrySteps: number;
model: ModelParam; model: ModelParam;
vae: VAEParam;
shouldUseSeamless: boolean; shouldUseSeamless: boolean;
seamlessXAxis: boolean; seamlessXAxis: boolean;
seamlessYAxis: boolean; seamlessYAxis: boolean;
@ -81,6 +83,7 @@ export const initialGenerationState: GenerationState = {
horizontalSymmetrySteps: 0, horizontalSymmetrySteps: 0,
verticalSymmetrySteps: 0, verticalSymmetrySteps: 0,
model: '', model: '',
vae: '',
shouldUseSeamless: false, shouldUseSeamless: false,
seamlessXAxis: true, seamlessXAxis: true,
seamlessYAxis: true, seamlessYAxis: true,
@ -216,6 +219,9 @@ export const generationSlice = createSlice({
modelSelected: (state, action: PayloadAction<string>) => { modelSelected: (state, action: PayloadAction<string>) => {
state.model = action.payload; state.model = action.payload;
}, },
vaeSelected: (state, action: PayloadAction<string>) => {
state.vae = action.payload;
},
}, },
extraReducers: (builder) => { extraReducers: (builder) => {
builder.addCase(configChanged, (state, action) => { builder.addCase(configChanged, (state, action) => {
@ -260,6 +266,7 @@ export const {
setVerticalSymmetrySteps, setVerticalSymmetrySteps,
initialImageChanged, initialImageChanged,
modelSelected, modelSelected,
vaeSelected,
setShouldUseNoiseSettings, setShouldUseNoiseSettings,
setSeamless, setSeamless,
setSeamlessXAxis, setSeamlessXAxis,

View File

@ -135,6 +135,15 @@ export const zModel = z.string();
* Type alias for model parameter, inferred from its zod schema * Type alias for model parameter, inferred from its zod schema
*/ */
export type ModelParam = z.infer<typeof zModel>; export type ModelParam = z.infer<typeof zModel>;
/**
* Zod schema for VAE parameter
* TODO: Make this a dynamically generated enum?
*/
export const zVAE = z.string();
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type VAEParam = z.infer<typeof zVAE>;
/** /**
* Validates/type-guards a value as a model parameter * Validates/type-guards a value as a model parameter
*/ */

View File

@ -1,125 +0,0 @@
import {
Button,
Flex,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
Text,
useDisclosure,
} from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton';
import { FaArrowLeft, FaPlus } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useTranslation } from 'react-i18next';
import type { RootState } from 'app/store/store';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import AddCheckpointModel from './AddCheckpointModel';
import AddDiffusersModel from './AddDiffusersModel';
import IAIIconButton from 'common/components/IAIIconButton';
function AddModelBox({
text,
onClick,
}: {
text: string;
onClick?: () => void;
}) {
return (
<Flex
position="relative"
width="50%"
height={40}
justifyContent="center"
alignItems="center"
onClick={onClick}
as={Button}
>
<Text fontWeight="bold">{text}</Text>
</Flex>
);
}
export default function AddModel() {
const { isOpen, onOpen, onClose } = useDisclosure();
const addNewModelUIOption = useAppSelector(
(state: RootState) => state.ui.addNewModelUIOption
);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const addModelModalClose = () => {
onClose();
dispatch(setAddNewModelUIOption(null));
};
return (
<>
<IAIButton
aria-label={t('modelManager.addNewModel')}
tooltip={t('modelManager.addNewModel')}
onClick={onOpen}
size="sm"
>
<Flex columnGap={2} alignItems="center">
<FaPlus />
{t('modelManager.addNew')}
</Flex>
</IAIButton>
<Modal
isOpen={isOpen}
onClose={addModelModalClose}
size="3xl"
closeOnOverlayClick={false}
>
<ModalOverlay />
<ModalContent margin="auto">
<ModalHeader>{t('modelManager.addNewModel')} </ModalHeader>
{addNewModelUIOption !== null && (
<IAIIconButton
aria-label={t('common.back')}
tooltip={t('common.back')}
onClick={() => dispatch(setAddNewModelUIOption(null))}
position="absolute"
variant="ghost"
zIndex={1}
size="sm"
insetInlineEnd={12}
top={2}
icon={<FaArrowLeft />}
/>
)}
<ModalCloseButton />
<ModalBody>
{addNewModelUIOption == null && (
<Flex columnGap={4}>
<AddModelBox
text={t('modelManager.addCheckpointModel')}
onClick={() => dispatch(setAddNewModelUIOption('ckpt'))}
/>
<AddModelBox
text={t('modelManager.addDiffuserModel')}
onClick={() => dispatch(setAddNewModelUIOption('diffusers'))}
/>
</Flex>
)}
{addNewModelUIOption == 'ckpt' && <AddCheckpointModel />}
{addNewModelUIOption == 'diffusers' && <AddDiffusersModel />}
</ModalBody>
<ModalFooter />
</ModalContent>
</Modal>
</>
);
}

View File

@ -1,339 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAINumberInput from 'common/components/IAINumberInput';
import { useEffect, useState } from 'react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { systemSelector } from 'features/system/store/systemSelectors';
import {
Flex,
FormControl,
FormLabel,
HStack,
Text,
VStack,
} from '@chakra-ui/react';
// import { addNewModel } from 'app/socketio/actions';
import { Field, Formik } from 'formik';
import { useTranslation } from 'react-i18next';
import type { InvokeModelConfigProps } from 'app/types/invokeai';
import type { RootState } from 'app/store/store';
import type { FieldInputProps, FormikProps } from 'formik';
import { isEqual, pickBy } from 'lodash-es';
import ModelConvert from './ModelConvert';
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
import IAIForm from 'common/components/IAIForm';
const selector = createSelector(
[systemSelector],
(system) => {
const { openModel, model_list } = system;
return {
model_list,
openModel,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const MIN_MODEL_SIZE = 64;
const MAX_MODEL_SIZE = 2048;
export default function CheckpointModelEdit() {
const { openModel, model_list } = useAppSelector(selector);
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const [editModelFormValues, setEditModelFormValues] =
useState<InvokeModelConfigProps>({
name: '',
description: '',
config: 'configs/stable-diffusion/v1-inference.yaml',
weights: '',
vae: '',
width: 512,
height: 512,
default: false,
format: 'ckpt',
});
useEffect(() => {
if (openModel) {
const retrievedModel = pickBy(model_list, (_val, key) => {
return isEqual(key, openModel);
});
setEditModelFormValues({
name: openModel,
description: retrievedModel[openModel]?.description,
config: retrievedModel[openModel]?.config,
weights: retrievedModel[openModel]?.weights,
vae: retrievedModel[openModel]?.vae,
width: retrievedModel[openModel]?.width,
height: retrievedModel[openModel]?.height,
default: retrievedModel[openModel]?.default,
format: 'ckpt',
});
}
}, [model_list, openModel]);
const editModelFormSubmitHandler = (values: InvokeModelConfigProps) => {
dispatch(
addNewModel({
...values,
width: Number(values.width),
height: Number(values.height),
})
);
};
return openModel ? (
<Flex flexDirection="column" rowGap={4} width="100%">
<Flex alignItems="center" gap={4} justifyContent="space-between">
<Text fontSize="lg" fontWeight="bold">
{openModel}
</Text>
<ModelConvert model={openModel} />
</Flex>
<Flex
flexDirection="column"
maxHeight={window.innerHeight - 270}
overflowY="scroll"
paddingInlineEnd={8}
>
<Formik
enableReinitialize={true}
initialValues={editModelFormValues}
onSubmit={editModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<IAIForm onSubmit={handleSubmit}>
<VStack rowGap={2} alignItems="start">
{/* Description */}
<FormControl
isInvalid={!!errors.description && touched.description}
isRequired
>
<FormLabel htmlFor="description" fontSize="sm">
{t('modelManager.description')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="description"
name="description"
type="text"
width="full"
/>
{!!errors.description && touched.description ? (
<IAIFormErrorMessage>
{errors.description}
</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.descriptionValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* Config */}
<FormControl
isInvalid={!!errors.config && touched.config}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.config')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="config"
name="config"
type="text"
width="full"
/>
{!!errors.config && touched.config ? (
<IAIFormErrorMessage>{errors.config}</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.configValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* Weights */}
<FormControl
isInvalid={!!errors.weights && touched.weights}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.modelLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="weights"
name="weights"
type="text"
width="full"
/>
{!!errors.weights && touched.weights ? (
<IAIFormErrorMessage>
{errors.weights}
</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.modelLocationValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* VAE */}
<FormControl isInvalid={!!errors.vae && touched.vae}>
<FormLabel htmlFor="vae" fontSize="sm">
{t('modelManager.vaeLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae"
name="vae"
type="text"
width="full"
/>
{!!errors.vae && touched.vae ? (
<IAIFormErrorMessage>{errors.vae}</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.vaeLocationValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
<HStack width="100%">
{/* Width */}
<FormControl isInvalid={!!errors.width && touched.width}>
<FormLabel htmlFor="width" fontSize="sm">
{t('modelManager.width')}
</FormLabel>
<VStack alignItems="start">
<Field id="width" name="width">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="width"
name="width"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
value={form.values.width}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
{!!errors.width && touched.width ? (
<IAIFormErrorMessage>
{errors.width}
</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.widthValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* Height */}
<FormControl isInvalid={!!errors.height && touched.height}>
<FormLabel htmlFor="height" fontSize="sm">
{t('modelManager.height')}
</FormLabel>
<VStack alignItems="start">
<Field id="height" name="height">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="height"
name="height"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
value={form.values.height}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
{!!errors.height && touched.height ? (
<IAIFormErrorMessage>
{errors.height}
</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.heightValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
</HStack>
<IAIButton
type="submit"
className="modal-close-btn"
isLoading={isProcessing}
>
{t('modelManager.updateModel')}
</IAIButton>
</VStack>
</IAIForm>
)}
</Formik>
</Flex>
</Flex>
) : (
<Flex
sx={{
width: '100%',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
bg: 'base.900',
}}
>
<Text fontWeight={500}>Pick A Model To Edit</Text>
</Flex>
);
}

View File

@ -1,281 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import { useEffect, useState } from 'react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { systemSelector } from 'features/system/store/systemSelectors';
import { Flex, FormControl, FormLabel, Text, VStack } from '@chakra-ui/react';
// import { addNewModel } from 'app/socketio/actions';
import { Field, Formik } from 'formik';
import { useTranslation } from 'react-i18next';
import type { InvokeDiffusersModelConfigProps } from 'app/types/invokeai';
import type { RootState } from 'app/store/store';
import { isEqual, pickBy } from 'lodash-es';
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
import IAIForm from 'common/components/IAIForm';
const selector = createSelector(
[systemSelector],
(system) => {
const { openModel, model_list } = system;
return {
model_list,
openModel,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
export default function DiffusersModelEdit() {
const { openModel, model_list } = useAppSelector(selector);
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const [editModelFormValues, setEditModelFormValues] =
useState<InvokeDiffusersModelConfigProps>({
name: '',
description: '',
repo_id: '',
path: '',
vae: { repo_id: '', path: '' },
default: false,
format: 'diffusers',
});
useEffect(() => {
if (openModel) {
const retrievedModel = pickBy(model_list, (_val, key) => {
return isEqual(key, openModel);
});
setEditModelFormValues({
name: openModel,
description: retrievedModel[openModel]?.description,
path:
retrievedModel[openModel]?.path &&
retrievedModel[openModel]?.path !== 'None'
? retrievedModel[openModel]?.path
: '',
repo_id:
retrievedModel[openModel]?.repo_id &&
retrievedModel[openModel]?.repo_id !== 'None'
? retrievedModel[openModel]?.repo_id
: '',
vae: {
repo_id: retrievedModel[openModel]?.vae?.repo_id
? retrievedModel[openModel]?.vae?.repo_id
: '',
path: retrievedModel[openModel]?.vae?.path
? retrievedModel[openModel]?.vae?.path
: '',
},
default: retrievedModel[openModel]?.default,
format: 'diffusers',
});
}
}, [model_list, openModel]);
const editModelFormSubmitHandler = (
values: InvokeDiffusersModelConfigProps
) => {
const diffusersModelToEdit = values;
if (values.path === '') delete diffusersModelToEdit.path;
if (values.repo_id === '') delete diffusersModelToEdit.repo_id;
if (values.vae.path === '') delete diffusersModelToEdit.vae.path;
if (values.vae.repo_id === '') delete diffusersModelToEdit.vae.repo_id;
dispatch(addNewModel(values));
};
return openModel ? (
<Flex flexDirection="column" rowGap={4} width="100%">
<Flex alignItems="center">
<Text fontSize="lg" fontWeight="bold">
{openModel}
</Text>
</Flex>
<Flex flexDirection="column" overflowY="scroll" paddingInlineEnd={8}>
<Formik
enableReinitialize={true}
initialValues={editModelFormValues}
onSubmit={editModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<IAIForm onSubmit={handleSubmit}>
<VStack rowGap={2} alignItems="start">
{/* Description */}
<FormControl
isInvalid={!!errors.description && touched.description}
isRequired
>
<FormLabel htmlFor="description" fontSize="sm">
{t('modelManager.description')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="description"
name="description"
type="text"
width="full"
/>
{!!errors.description && touched.description ? (
<IAIFormErrorMessage>
{errors.description}
</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.descriptionValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* Path */}
<FormControl
isInvalid={!!errors.path && touched.path}
isRequired
>
<FormLabel htmlFor="path" fontSize="sm">
{t('modelManager.modelLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="path"
name="path"
type="text"
width="full"
/>
{!!errors.path && touched.path ? (
<IAIFormErrorMessage>{errors.path}</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.modelLocationValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* Repo ID */}
<FormControl isInvalid={!!errors.repo_id && touched.repo_id}>
<FormLabel htmlFor="repo_id" fontSize="sm">
{t('modelManager.repo_id')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="repo_id"
name="repo_id"
type="text"
width="full"
/>
{!!errors.repo_id && touched.repo_id ? (
<IAIFormErrorMessage>
{errors.repo_id}
</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.repoIDValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* VAE Path */}
<FormControl
isInvalid={!!errors.vae?.path && touched.vae?.path}
>
<FormLabel htmlFor="vae.path" fontSize="sm">
{t('modelManager.vaeLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae.path"
name="vae.path"
type="text"
width="full"
/>
{!!errors.vae?.path && touched.vae?.path ? (
<IAIFormErrorMessage>
{errors.vae?.path}
</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.vaeLocationValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* VAE Repo ID */}
<FormControl
isInvalid={!!errors.vae?.repo_id && touched.vae?.repo_id}
>
<FormLabel htmlFor="vae.repo_id" fontSize="sm">
{t('modelManager.vaeRepoID')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae.repo_id"
name="vae.repo_id"
type="text"
width="full"
/>
{!!errors.vae?.repo_id && touched.vae?.repo_id ? (
<IAIFormErrorMessage>
{errors.vae?.repo_id}
</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.vaeRepoIDValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
<IAIButton
type="submit"
className="modal-close-btn"
isLoading={isProcessing}
>
{t('modelManager.updateModel')}
</IAIButton>
</VStack>
</IAIForm>
)}
</Formik>
</Flex>
</Flex>
) : (
<Flex
sx={{
width: '100%',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
bg: 'base.900',
}}
>
<Text fontWeight={'500'}>Pick A Model To Edit</Text>
</Flex>
);
}

View File

@ -1,313 +0,0 @@
import {
Flex,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
Radio,
RadioGroup,
Text,
Tooltip,
useDisclosure,
} from '@chakra-ui/react';
// import { mergeDiffusersModels } from 'app/socketio/actions';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAISelect from 'common/components/IAISelect';
import { diffusersModelsSelector } from 'features/system/store/systemSelectors';
import { useState } from 'react';
import { useTranslation } from 'react-i18next';
import * as InvokeAI from 'app/types/invokeai';
import IAISlider from 'common/components/IAISlider';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
export default function MergeModels() {
const dispatch = useAppDispatch();
const { isOpen, onOpen, onClose } = useDisclosure();
const diffusersModels = useAppSelector(diffusersModelsSelector);
const { t } = useTranslation();
const [modelOne, setModelOne] = useState<string>(
Object.keys(diffusersModels)[0]
);
const [modelTwo, setModelTwo] = useState<string>(
Object.keys(diffusersModels)[1]
);
const [modelThree, setModelThree] = useState<string>('none');
const [mergedModelName, setMergedModelName] = useState<string>('');
const [modelMergeAlpha, setModelMergeAlpha] = useState<number>(0.5);
const [modelMergeInterp, setModelMergeInterp] = useState<
'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference'
>('weighted_sum');
const [modelMergeSaveLocType, setModelMergeSaveLocType] = useState<
'root' | 'custom'
>('root');
const [modelMergeCustomSaveLoc, setModelMergeCustomSaveLoc] =
useState<string>('');
const [modelMergeForce, setModelMergeForce] = useState<boolean>(false);
const modelOneList = Object.keys(diffusersModels).filter(
(model) => model !== modelTwo && model !== modelThree
);
const modelTwoList = Object.keys(diffusersModels).filter(
(model) => model !== modelOne && model !== modelThree
);
const modelThreeList = [
{ key: t('modelManager.none'), value: 'none' },
...Object.keys(diffusersModels)
.filter((model) => model !== modelOne && model !== modelTwo)
.map((model) => ({ key: model, value: model })),
];
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const mergeModelsHandler = () => {
let modelsToMerge: string[] = [modelOne, modelTwo, modelThree];
modelsToMerge = modelsToMerge.filter((model) => model !== 'none');
const mergeModelsInfo: InvokeAI.InvokeModelMergingProps = {
models_to_merge: modelsToMerge,
merged_model_name:
mergedModelName !== '' ? mergedModelName : modelsToMerge.join('-'),
alpha: modelMergeAlpha,
interp: modelMergeInterp,
model_merge_save_path:
modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc,
force: modelMergeForce,
};
dispatch(mergeDiffusersModels(mergeModelsInfo));
};
return (
<>
<IAIButton onClick={onOpen} size="sm">
<Flex columnGap={2} alignItems="center">
{t('modelManager.mergeModels')}
</Flex>
</IAIButton>
<Modal
isOpen={isOpen}
onClose={onClose}
size="4xl"
closeOnOverlayClick={false}
>
<ModalOverlay />
<ModalContent fontFamily="Inter" margin="auto" paddingInlineEnd={4}>
<ModalHeader>{t('modelManager.mergeModels')}</ModalHeader>
<ModalCloseButton />
<ModalBody>
<Flex flexDirection="column" rowGap={4}>
<Flex
sx={{
flexDirection: 'column',
marginBottom: 4,
padding: 4,
borderRadius: 'base',
rowGap: 1,
bg: 'base.900',
}}
>
<Text>{t('modelManager.modelMergeHeaderHelp1')}</Text>
<Text fontSize="sm" variant="subtext">
{t('modelManager.modelMergeHeaderHelp2')}
</Text>
</Flex>
<Flex columnGap={4}>
<IAISelect
label={t('modelManager.modelOne')}
validValues={modelOneList}
onChange={(e) => setModelOne(e.target.value)}
/>
<IAISelect
label={t('modelManager.modelTwo')}
validValues={modelTwoList}
onChange={(e) => setModelTwo(e.target.value)}
/>
<IAISelect
label={t('modelManager.modelThree')}
validValues={modelThreeList}
onChange={(e) => {
if (e.target.value !== 'none') {
setModelThree(e.target.value);
setModelMergeInterp('add_difference');
} else {
setModelThree('none');
setModelMergeInterp('weighted_sum');
}
}}
/>
</Flex>
<IAIInput
label={t('modelManager.mergedModelName')}
value={mergedModelName}
onChange={(e) => setMergedModelName(e.target.value)}
/>
<Flex
sx={{
flexDirection: 'column',
padding: 4,
borderRadius: 'base',
gap: 4,
bg: 'base.900',
}}
>
<IAISlider
label={t('modelManager.alpha')}
min={0.01}
max={0.99}
step={0.01}
value={modelMergeAlpha}
onChange={(v) => setModelMergeAlpha(v)}
withInput
withReset
handleReset={() => setModelMergeAlpha(0.5)}
withSliderMarks
/>
<Text variant="subtext" fontSize="sm">
{t('modelManager.modelMergeAlphaHelp')}
</Text>
</Flex>
<Flex
sx={{
padding: 4,
borderRadius: 'base',
gap: 4,
bg: 'base.900',
}}
>
<Text fontWeight={500} fontSize="sm" variant="subtext">
{t('modelManager.interpolationType')}
</Text>
<RadioGroup
value={modelMergeInterp}
onChange={(
v:
| 'weighted_sum'
| 'sigmoid'
| 'inv_sigmoid'
| 'add_difference'
) => setModelMergeInterp(v)}
>
<Flex columnGap={4}>
{modelThree === 'none' ? (
<>
<Radio value="weighted_sum">
<Text fontSize="sm">
{t('modelManager.weightedSum')}
</Text>
</Radio>
<Radio value="sigmoid">
<Text fontSize="sm">{t('modelManager.sigmoid')}</Text>
</Radio>
<Radio value="inv_sigmoid">
<Text fontSize="sm">
{t('modelManager.inverseSigmoid')}
</Text>
</Radio>
</>
) : (
<Radio value="add_difference">
<Tooltip
label={t(
'modelManager.modelMergeInterpAddDifferenceHelp'
)}
>
<Text fontSize="sm">
{t('modelManager.addDifference')}
</Text>
</Tooltip>
</Radio>
)}
</Flex>
</RadioGroup>
</Flex>
<Flex
sx={{
flexDirection: 'column',
padding: 4,
borderRadius: 'base',
gap: 4,
bg: 'base.900',
}}
>
<Flex columnGap={4}>
<Text fontWeight="500" fontSize="sm" variant="subtext">
{t('modelManager.mergedModelSaveLocation')}
</Text>
<RadioGroup
value={modelMergeSaveLocType}
onChange={(v: 'root' | 'custom') =>
setModelMergeSaveLocType(v)
}
>
<Flex columnGap={4}>
<Radio value="root">
<Text fontSize="sm">
{t('modelManager.invokeAIFolder')}
</Text>
</Radio>
<Radio value="custom">
<Text fontSize="sm">{t('modelManager.custom')}</Text>
</Radio>
</Flex>
</RadioGroup>
</Flex>
{modelMergeSaveLocType === 'custom' && (
<IAIInput
label={t('modelManager.mergedModelCustomSaveLocation')}
value={modelMergeCustomSaveLoc}
onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)}
/>
)}
</Flex>
<IAISimpleCheckbox
label={t('modelManager.ignoreMismatch')}
isChecked={modelMergeForce}
onChange={(e) => setModelMergeForce(e.target.checked)}
fontWeight="500"
/>
<IAIButton
onClick={mergeModelsHandler}
isLoading={isProcessing}
isDisabled={
modelMergeSaveLocType === 'custom' &&
modelMergeCustomSaveLoc === ''
}
>
{t('modelManager.merge')}
</IAIButton>
</Flex>
</ModalBody>
<ModalFooter />
</ModalContent>
</Modal>
</>
);
}

View File

@ -1,76 +0,0 @@
import {
Flex,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
useDisclosure,
} from '@chakra-ui/react';
import { cloneElement } from 'react';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { useTranslation } from 'react-i18next';
import type { ReactElement } from 'react';
import CheckpointModelEdit from './CheckpointModelEdit';
import DiffusersModelEdit from './DiffusersModelEdit';
import ModelList from './ModelList';
type ModelManagerModalProps = {
children: ReactElement;
};
export default function ModelManagerModal({
children,
}: ModelManagerModalProps) {
const {
isOpen: isModelManagerModalOpen,
onOpen: onModelManagerModalOpen,
onClose: onModelManagerModalClose,
} = useDisclosure();
const model_list = useAppSelector(
(state: RootState) => state.system.model_list
);
const openModel = useAppSelector(
(state: RootState) => state.system.openModel
);
const { t } = useTranslation();
return (
<>
{cloneElement(children, {
onClick: onModelManagerModalOpen,
})}
<Modal
isOpen={isModelManagerModalOpen}
onClose={onModelManagerModalClose}
size="full"
>
<ModalOverlay />
<ModalContent>
<ModalCloseButton />
<ModalHeader>{t('modelManager.modelManager')}</ModalHeader>
<ModalBody>
<Flex width="100%" columnGap={8}>
<ModelList />
{openModel && model_list[openModel]['format'] === 'diffusers' ? (
<DiffusersModelEdit />
) : (
<CheckpointModelEdit />
)}
</Flex>
</ModalBody>
<ModalFooter />
</ModalContent>
</Modal>
</>
);
}

View File

@ -5,9 +5,9 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { modelSelected } from 'features/parameters/store/generationSlice'; import { modelSelected } from 'features/parameters/store/generationSlice';
import { forEach, isString } from 'lodash-es';
import { SelectItem } from '@mantine/core'; import { SelectItem } from '@mantine/core';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { forEach, isString } from 'lodash-es';
import { useListModelsQuery } from 'services/api/endpoints/models'; import { useListModelsQuery } from 'services/api/endpoints/models';
export const MODEL_TYPE_MAP = { export const MODEL_TYPE_MAP = {
@ -23,18 +23,18 @@ const ModelSelect = () => {
(state: RootState) => state.generation.model (state: RootState) => state.generation.model
); );
const { data: pipelineModels, isLoading } = useListModelsQuery({ const { data: mainModels, isLoading } = useListModelsQuery({
model_type: 'main', model_type: 'main',
}); });
const data = useMemo(() => { const data = useMemo(() => {
if (!pipelineModels) { if (!mainModels) {
return []; return [];
} }
const data: SelectItem[] = []; const data: SelectItem[] = [];
forEach(pipelineModels.entities, (model, id) => { forEach(mainModels.entities, (model, id) => {
if (!model) { if (!model) {
return; return;
} }
@ -47,11 +47,11 @@ const ModelSelect = () => {
}); });
return data; return data;
}, [pipelineModels]); }, [mainModels]);
const selectedModel = useMemo( const selectedModel = useMemo(
() => pipelineModels?.entities[selectedModelId], () => mainModels?.entities[selectedModelId],
[pipelineModels?.entities, selectedModelId] [mainModels?.entities, selectedModelId]
); );
const handleChangeModel = useCallback( const handleChangeModel = useCallback(
@ -65,20 +65,18 @@ const ModelSelect = () => {
); );
useEffect(() => { useEffect(() => {
// If the selected model is not in the list of models, select the first one if (selectedModelId && mainModels?.ids.includes(selectedModelId)) {
// Handles first-run setting of models, and the user deleting the previously-selected model
if (selectedModelId && pipelineModels?.ids.includes(selectedModelId)) {
return; return;
} }
const firstModel = pipelineModels?.ids[0]; const firstModel = mainModels?.ids[0];
if (!isString(firstModel)) { if (!isString(firstModel)) {
return; return;
} }
handleChangeModel(firstModel); handleChangeModel(firstModel);
}, [handleChangeModel, pipelineModels?.ids, selectedModelId]); }, [handleChangeModel, mainModels?.ids, selectedModelId]);
return isLoading ? ( return isLoading ? (
<IAIMantineSelect <IAIMantineSelect

View File

@ -5,21 +5,18 @@ import StatusIndicator from './StatusIndicator';
import { Link } from '@chakra-ui/react'; import { Link } from '@chakra-ui/react';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { FaBug, FaCube, FaDiscord, FaGithub, FaKeyboard } from 'react-icons/fa'; import { FaBug, FaDiscord, FaGithub, FaKeyboard } from 'react-icons/fa';
import { MdSettings } from 'react-icons/md'; import { MdSettings } from 'react-icons/md';
import { useFeatureStatus } from '../hooks/useFeatureStatus';
import ColorModeButton from './ColorModeButton';
import HotkeysModal from './HotkeysModal/HotkeysModal'; import HotkeysModal from './HotkeysModal/HotkeysModal';
import InvokeAILogoComponent from './InvokeAILogoComponent'; import InvokeAILogoComponent from './InvokeAILogoComponent';
import LanguagePicker from './LanguagePicker'; import LanguagePicker from './LanguagePicker';
import ModelManagerModal from './ModelManager/ModelManagerModal';
import SettingsModal from './SettingsModal/SettingsModal'; import SettingsModal from './SettingsModal/SettingsModal';
import { useFeatureStatus } from '../hooks/useFeatureStatus';
import ColorModeButton from './ColorModeButton';
const SiteHeader = () => { const SiteHeader = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const isModelManagerEnabled =
useFeatureStatus('modelManager').isFeatureEnabled;
const isLocalizationEnabled = const isLocalizationEnabled =
useFeatureStatus('localization').isFeatureEnabled; useFeatureStatus('localization').isFeatureEnabled;
const isBugLinkEnabled = useFeatureStatus('bugLink').isFeatureEnabled; const isBugLinkEnabled = useFeatureStatus('bugLink').isFeatureEnabled;
@ -37,20 +34,6 @@ const SiteHeader = () => {
<Spacer /> <Spacer />
<StatusIndicator /> <StatusIndicator />
{isModelManagerEnabled && (
<ModelManagerModal>
<IAIIconButton
aria-label={t('modelManager.modelManager')}
tooltip={t('modelManager.modelManager')}
size="sm"
variant="link"
data-variant="link"
fontSize={20}
icon={<FaCube />}
/>
</ModelManagerModal>
)}
<HotkeysModal> <HotkeysModal>
<IAIIconButton <IAIIconButton
aria-label={t('common.hotkeysLabel')} aria-label={t('common.hotkeysLabel')}

View File

@ -1,10 +1,9 @@
import { Flex, Link } from '@chakra-ui/react'; import { Flex, Link } from '@chakra-ui/react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { FaBug, FaCube, FaDiscord, FaGithub, FaKeyboard } from 'react-icons/fa'; import { FaBug, FaDiscord, FaGithub, FaKeyboard } from 'react-icons/fa';
import { MdSettings } from 'react-icons/md'; import { MdSettings } from 'react-icons/md';
import HotkeysModal from './HotkeysModal/HotkeysModal'; import HotkeysModal from './HotkeysModal/HotkeysModal';
import LanguagePicker from './LanguagePicker'; import LanguagePicker from './LanguagePicker';
import ModelManagerModal from './ModelManager/ModelManagerModal';
import SettingsModal from './SettingsModal/SettingsModal'; import SettingsModal from './SettingsModal/SettingsModal';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
@ -13,8 +12,6 @@ import { useFeatureStatus } from '../hooks/useFeatureStatus';
const SiteHeaderMenu = () => { const SiteHeaderMenu = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const isModelManagerEnabled =
useFeatureStatus('modelManager').isFeatureEnabled;
const isLocalizationEnabled = const isLocalizationEnabled =
useFeatureStatus('localization').isFeatureEnabled; useFeatureStatus('localization').isFeatureEnabled;
const isBugLinkEnabled = useFeatureStatus('bugLink').isFeatureEnabled; const isBugLinkEnabled = useFeatureStatus('bugLink').isFeatureEnabled;
@ -27,20 +24,6 @@ const SiteHeaderMenu = () => {
flexDirection={{ base: 'column', xl: 'row' }} flexDirection={{ base: 'column', xl: 'row' }}
gap={{ base: 4, xl: 1 }} gap={{ base: 4, xl: 1 }}
> >
{isModelManagerEnabled && (
<ModelManagerModal>
<IAIIconButton
aria-label={t('modelManager.modelManager')}
tooltip={t('modelManager.modelManager')}
size="sm"
variant="link"
data-variant="link"
fontSize={20}
icon={<FaCube />}
/>
</ModelManagerModal>
)}
<HotkeysModal> <HotkeysModal>
<IAIIconButton <IAIIconButton
aria-label={t('common.hotkeysLabel')} aria-label={t('common.hotkeysLabel')}

View File

@ -0,0 +1,89 @@
import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { SelectItem } from '@mantine/core';
import { forEach } from 'lodash-es';
import { useListModelsQuery } from 'services/api/endpoints/models';
import { RootState } from 'app/store/store';
import { vaeSelected } from 'features/parameters/store/generationSlice';
import { MODEL_TYPE_MAP } from './ModelSelect';
const VAESelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { data: vaeModels } = useListModelsQuery({
model_type: 'vae',
});
const selectedModelId = useAppSelector(
(state: RootState) => state.generation.vae
);
const data = useMemo(() => {
if (!vaeModels) {
return [];
}
const data: SelectItem[] = [
{
value: 'auto',
label: 'Automatic',
group: 'Default',
},
];
forEach(vaeModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [vaeModels]);
const selectedModel = useMemo(
() => vaeModels?.entities[selectedModelId],
[vaeModels?.entities, selectedModelId]
);
const handleChangeModel = useCallback(
(v: string | null) => {
if (!v) {
return;
}
dispatch(vaeSelected(v));
},
[dispatch]
);
useEffect(() => {
if (selectedModelId && vaeModels?.ids.includes(selectedModelId)) {
return;
}
handleChangeModel('auto');
}, [handleChangeModel, vaeModels?.ids, selectedModelId]);
return (
<IAIMantineSelect
tooltip={selectedModel?.description}
label={t('modelManager.vae')}
value={selectedModelId}
placeholder="Pick one"
data={data}
onChange={handleChangeModel}
/>
);
};
export default memo(VAESelect);

View File

@ -1,13 +1,14 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import { useTranslation } from 'react-i18next';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { setShouldShowGallery } from 'features/ui/store/uiSlice'; import { setShouldShowGallery } from 'features/ui/store/uiSlice';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { MdPhotoLibrary } from 'react-icons/md'; import { MdPhotoLibrary } from 'react-icons/md';
import { activeTabNameSelector, uiSelector } from '../store/uiSelectors'; import { activeTabNameSelector, uiSelector } from '../store/uiSelectors';
import { memo } from 'react'; import { NO_GALLERY_TABS } from './InvokeTabs';
const floatingGalleryButtonSelector = createSelector( const floatingGalleryButtonSelector = createSelector(
[activeTabNameSelector, uiSelector], [activeTabNameSelector, uiSelector],
@ -16,7 +17,9 @@ const floatingGalleryButtonSelector = createSelector(
return { return {
shouldPinGallery, shouldPinGallery,
shouldShowGalleryButton: !shouldShowGallery, shouldShowGalleryButton: NO_GALLERY_TABS.includes(activeTabName)
? false
: !shouldShowGallery,
}; };
}, },
{ memoizeOptions: { resultEqualityCheck: isEqual } } { memoizeOptions: { resultEqualityCheck: isEqual } }

View File

@ -9,35 +9,35 @@ import {
Tooltip, Tooltip,
VisuallyHidden, VisuallyHidden,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import AuxiliaryProgressIndicator from 'app/components/AuxiliaryProgressIndicator';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent';
import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice'; import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
import { configSelector } from 'features/system/store/configSelectors';
import { InvokeTabName } from 'features/ui/store/tabMap'; import { InvokeTabName } from 'features/ui/store/tabMap';
import { setActiveTab, togglePanels } from 'features/ui/store/uiSlice'; import { setActiveTab, togglePanels } from 'features/ui/store/uiSlice';
import { memo, MouseEvent, ReactNode, useCallback, useMemo } from 'react'; import { ResourceKey } from 'i18next';
import { isEqual } from 'lodash-es';
import { MouseEvent, ReactNode, memo, useCallback, useMemo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook'; import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaCube, FaFont, FaImage } from 'react-icons/fa';
import { MdDeviceHub, MdGridOn } from 'react-icons/md'; import { MdDeviceHub, MdGridOn } from 'react-icons/md';
import { Panel, PanelGroup } from 'react-resizable-panels';
import { useMinimumPanelSize } from '../hooks/useMinimumPanelSize';
import { import {
activeTabIndexSelector, activeTabIndexSelector,
activeTabNameSelector, activeTabNameSelector,
} from '../store/uiSelectors'; } from '../store/uiSelectors';
import { useTranslation } from 'react-i18next'; import ImageTab from './tabs/ImageToImage/ImageToImageTab';
import { ResourceKey } from 'i18next'; import ModelManagerTab from './tabs/ModelManager/ModelManagerTab';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import NodesTab from './tabs/Nodes/NodesTab';
import { createSelector } from '@reduxjs/toolkit'; import ResizeHandle from './tabs/ResizeHandle';
import { configSelector } from 'features/system/store/configSelectors';
import { isEqual } from 'lodash-es';
import { Panel, PanelGroup } from 'react-resizable-panels';
import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent';
import TextToImageTab from './tabs/TextToImage/TextToImageTab'; import TextToImageTab from './tabs/TextToImage/TextToImageTab';
import UnifiedCanvasTab from './tabs/UnifiedCanvas/UnifiedCanvasTab'; import UnifiedCanvasTab from './tabs/UnifiedCanvas/UnifiedCanvasTab';
import NodesTab from './tabs/Nodes/NodesTab';
import { FaFont, FaImage, FaLayerGroup } from 'react-icons/fa';
import ResizeHandle from './tabs/ResizeHandle';
import ImageTab from './tabs/ImageToImage/ImageToImageTab';
import AuxiliaryProgressIndicator from 'app/components/AuxiliaryProgressIndicator';
import { useMinimumPanelSize } from '../hooks/useMinimumPanelSize';
import BatchTab from './tabs/Batch/BatchTab';
export interface InvokeTabInfo { export interface InvokeTabInfo {
id: InvokeTabName; id: InvokeTabName;
@ -71,6 +71,11 @@ const tabs: InvokeTabInfo[] = [
// icon: <Icon as={FaLayerGroup} sx={{ boxSize: 6, pointerEvents: 'none' }} />, // icon: <Icon as={FaLayerGroup} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
// content: <BatchTab />, // content: <BatchTab />,
// }, // },
{
id: 'modelManager',
icon: <Icon as={FaCube} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
content: <ModelManagerTab />,
},
]; ];
const enabledTabsSelector = createSelector( const enabledTabsSelector = createSelector(
@ -87,6 +92,7 @@ const enabledTabsSelector = createSelector(
const MIN_GALLERY_WIDTH = 300; const MIN_GALLERY_WIDTH = 300;
const DEFAULT_GALLERY_PCT = 20; const DEFAULT_GALLERY_PCT = 20;
export const NO_GALLERY_TABS: InvokeTabName[] = ['modelManager'];
const InvokeTabs = () => { const InvokeTabs = () => {
const activeTab = useAppSelector(activeTabIndexSelector); const activeTab = useAppSelector(activeTabIndexSelector);
@ -198,26 +204,28 @@ const InvokeTabs = () => {
{tabPanels} {tabPanels}
</TabPanels> </TabPanels>
</Panel> </Panel>
{shouldPinGallery && shouldShowGallery && ( {shouldPinGallery &&
<> shouldShowGallery &&
<ResizeHandle /> !NO_GALLERY_TABS.includes(activeTabName) && (
<Panel <>
ref={galleryPanelRef} <ResizeHandle />
onResize={handleResizeGallery} <Panel
id="gallery" ref={galleryPanelRef}
order={3} onResize={handleResizeGallery}
defaultSize={ id="gallery"
galleryMinSizePct > DEFAULT_GALLERY_PCT order={3}
? galleryMinSizePct defaultSize={
: DEFAULT_GALLERY_PCT galleryMinSizePct > DEFAULT_GALLERY_PCT
} ? galleryMinSizePct
minSize={galleryMinSizePct} : DEFAULT_GALLERY_PCT
maxSize={50} }
> minSize={galleryMinSizePct}
<ImageGalleryContent /> maxSize={50}
</Panel> >
</> <ImageGalleryContent />
)} </Panel>
</>
)}
</PanelGroup> </PanelGroup>
</Tabs> </Tabs>
); );

View File

@ -1,20 +1,21 @@
import { memo } from 'react';
import { Box, Flex, useDisclosure } from '@chakra-ui/react'; import { Box, Flex, useDisclosure } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations';
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth';
import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight';
import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength';
import ImageToImageFit from 'features/parameters/components/Parameters/ImageToImage/ImageToImageFit';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import ParamSchedulerAndModel from 'features/parameters/components/Parameters/Core/ParamSchedulerAndModel';
import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull';
import IAICollapse from 'common/components/IAICollapse'; import IAICollapse from 'common/components/IAICollapse';
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight';
import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations';
import ParamModelandVAE from 'features/parameters/components/Parameters/Core/ParamModelandVAE';
import ParamScheduler from 'features/parameters/components/Parameters/Core/ParamScheduler';
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth';
import ImageToImageFit from 'features/parameters/components/Parameters/ImageToImage/ImageToImageFit';
import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength';
import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo } from 'react';
const selector = createSelector( const selector = createSelector(
[uiSelector, generationSelector], [uiSelector, generationSelector],
@ -41,7 +42,7 @@ const ImageToImageTabCoreParameters = () => {
> >
{shouldUseSliders ? ( {shouldUseSliders ? (
<> <>
<ParamSchedulerAndModel /> <ParamModelandVAE />
<Box pt={2}> <Box pt={2}>
<ParamSeedFull /> <ParamSeedFull />
</Box> </Box>
@ -58,7 +59,8 @@ const ImageToImageTabCoreParameters = () => {
<ParamSteps /> <ParamSteps />
<ParamCFGScale /> <ParamCFGScale />
</Flex> </Flex>
<ParamSchedulerAndModel /> <ParamModelandVAE />
<ParamScheduler />
<Box pt={2}> <Box pt={2}>
<ParamSeedFull /> <ParamSeedFull />
</Box> </Box>

View File

@ -0,0 +1,81 @@
import { Tab, TabList, TabPanel, TabPanels, Tabs } from '@chakra-ui/react';
import i18n from 'i18n';
import { ReactNode, memo } from 'react';
import AddModelsPanel from './subpanels/AddModelsPanel';
import MergeModelsPanel from './subpanels/MergeModelsPanel';
import ModelManagerPanel from './subpanels/ModelManagerPanel';
type ModelManagerTabName = 'modelManager' | 'addModels' | 'mergeModels';
type ModelManagerTabInfo = {
id: ModelManagerTabName;
label: string;
content: ReactNode;
};
const modelManagerTabs: ModelManagerTabInfo[] = [
{
id: 'modelManager',
label: i18n.t('modelManager.modelManager'),
content: <ModelManagerPanel />,
},
{
id: 'addModels',
label: i18n.t('modelManager.addModel'),
content: <AddModelsPanel />,
},
{
id: 'mergeModels',
label: i18n.t('modelManager.mergeModels'),
content: <MergeModelsPanel />,
},
];
const renderTabsList = () => {
const modelManagerTabListsToRender: ReactNode[] = [];
modelManagerTabs.forEach((modelManagerTab) => {
modelManagerTabListsToRender.push(
<Tab key={modelManagerTab.id}>{modelManagerTab.label}</Tab>
);
});
return (
<TabList
sx={{
w: '100%',
color: 'base.200',
flexDirection: 'row',
borderBottomWidth: 2,
borderColor: 'accent.700',
}}
>
{modelManagerTabListsToRender}
</TabList>
);
};
const renderTabPanels = () => {
const modelManagerTabPanelsToRender: ReactNode[] = [];
modelManagerTabs.forEach((modelManagerTab) => {
modelManagerTabPanelsToRender.push(
<TabPanel key={modelManagerTab.id}>{modelManagerTab.content}</TabPanel>
);
});
return <TabPanels sx={{ p: 2 }}>{modelManagerTabPanelsToRender}</TabPanels>;
};
const ModelManagerTab = () => {
return (
<Tabs
isLazy
variant="invokeAI"
sx={{ w: 'full', h: 'full', p: 2, gap: 4, flexDirection: 'column' }}
>
{renderTabsList()}
{renderTabPanels()}
</Tabs>
);
};
export default memo(ModelManagerTab);

View File

@ -0,0 +1,55 @@
import { Divider, Flex } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import { useTranslation } from 'react-i18next';
import AddCheckpointModel from './AddModelsPanel/AddCheckpointModel';
import AddDiffusersModel from './AddModelsPanel/AddDiffusersModel';
export default function AddModelsPanel() {
const addNewModelUIOption = useAppSelector(
(state: RootState) => state.ui.addNewModelUIOption
);
const dispatch = useAppDispatch();
const { t } = useTranslation();
return (
<Flex flexDirection="column" gap={4}>
<Flex columnGap={4}>
<IAIButton
onClick={() => dispatch(setAddNewModelUIOption('ckpt'))}
sx={{
backgroundColor:
addNewModelUIOption == 'ckpt' ? 'accent.700' : 'base.700',
'&:hover': {
backgroundColor:
addNewModelUIOption == 'ckpt' ? 'accent.700' : 'base.600',
},
}}
>
{t('modelManager.addCheckpointModel')}
</IAIButton>
<IAIButton
onClick={() => dispatch(setAddNewModelUIOption('diffusers'))}
sx={{
backgroundColor:
addNewModelUIOption == 'diffusers' ? 'accent.700' : 'base.700',
'&:hover': {
backgroundColor:
addNewModelUIOption == 'diffusers' ? 'accent.700' : 'base.600',
},
}}
>
{t('modelManager.addDiffuserModel')}
</IAIButton>
</Flex>
<Divider />
{addNewModelUIOption == 'ckpt' && <AddCheckpointModel />}
{addNewModelUIOption == 'diffusers' && <AddDiffusersModel />}
</Flex>
);
}

View File

@ -10,13 +10,11 @@ import {
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import IAIInput from 'common/components/IAIInput'; import IAIInput from 'common/components/IAIInput';
import IAINumberInput from 'common/components/IAINumberInput'; import IAINumberInput from 'common/components/IAINumberInput';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import React from 'react'; import React from 'react';
import SearchModels from './SearchModels';
// import { addNewModel } from 'app/socketio/actions'; // import { addNewModel } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@ -24,12 +22,13 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { Field, Formik } from 'formik'; import { Field, Formik } from 'formik';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import type { InvokeModelConfigProps } from 'app/types/invokeai';
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice'; import type { InvokeModelConfigProps } from 'app/types/invokeai';
import type { FieldInputProps, FormikProps } from 'formik';
import IAIForm from 'common/components/IAIForm'; import IAIForm from 'common/components/IAIForm';
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper'; import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import type { FieldInputProps, FormikProps } from 'formik';
import SearchModels from './SearchModels';
const MIN_MODEL_SIZE = 64; const MIN_MODEL_SIZE = 64;
const MAX_MODEL_SIZE = 2048; const MAX_MODEL_SIZE = 2048;

View File

@ -66,7 +66,7 @@ export default function AddDiffusersModel() {
}; };
return ( return (
<Flex> <Flex overflow="scroll" maxHeight={window.innerHeight - 270}>
<Formik <Formik
initialValues={addModelFormValues} initialValues={addModelFormValues}
onSubmit={addModelFormSubmitHandler} onSubmit={addModelFormSubmitHandler}

View File

@ -0,0 +1,260 @@
import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAISelect from 'common/components/IAISelect';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import IAISlider from 'common/components/IAISlider';
import { pickBy } from 'lodash-es';
import { useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useListModelsQuery } from 'services/api/endpoints/models';
export default function MergeModelsPanel() {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { data } = useListModelsQuery({
model_type: 'main',
});
const diffusersModels = pickBy(
data?.entities,
(value, _) => value?.model_format === 'diffusers'
);
const [modelOne, setModelOne] = useState<string>(
Object.keys(diffusersModels)[0]
);
const [modelTwo, setModelTwo] = useState<string>(
Object.keys(diffusersModels)[1]
);
const [modelThree, setModelThree] = useState<string>('none');
const [mergedModelName, setMergedModelName] = useState<string>('');
const [modelMergeAlpha, setModelMergeAlpha] = useState<number>(0.5);
const [modelMergeInterp, setModelMergeInterp] = useState<
'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference'
>('weighted_sum');
const [modelMergeSaveLocType, setModelMergeSaveLocType] = useState<
'root' | 'custom'
>('root');
const [modelMergeCustomSaveLoc, setModelMergeCustomSaveLoc] =
useState<string>('');
const [modelMergeForce, setModelMergeForce] = useState<boolean>(false);
const modelOneList = Object.keys(diffusersModels).filter(
(model) => model !== modelTwo && model !== modelThree
);
const modelTwoList = Object.keys(diffusersModels).filter(
(model) => model !== modelOne && model !== modelThree
);
const modelThreeList = [
{ key: t('modelManager.none'), value: 'none' },
...Object.keys(diffusersModels)
.filter((model) => model !== modelOne && model !== modelTwo)
.map((model) => ({ key: model, value: model })),
];
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const mergeModelsHandler = () => {
let modelsToMerge: string[] = [modelOne, modelTwo, modelThree];
modelsToMerge = modelsToMerge.filter((model) => model !== 'none');
const mergeModelsInfo: InvokeAI.InvokeModelMergingProps = {
models_to_merge: modelsToMerge,
merged_model_name:
mergedModelName !== '' ? mergedModelName : modelsToMerge.join('-'),
alpha: modelMergeAlpha,
interp: modelMergeInterp,
model_merge_save_path:
modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc,
force: modelMergeForce,
};
dispatch(mergeDiffusersModels(mergeModelsInfo));
};
return (
<Flex flexDirection="column" rowGap={4}>
<Flex
sx={{
flexDirection: 'column',
rowGap: 1,
bg: 'base.900',
}}
>
<Text>{t('modelManager.modelMergeHeaderHelp1')}</Text>
<Text fontSize="sm" variant="subtext">
{t('modelManager.modelMergeHeaderHelp2')}
</Text>
</Flex>
<Flex columnGap={4}>
<IAISelect
label={t('modelManager.modelOne')}
validValues={modelOneList}
onChange={(e) => setModelOne(e.target.value)}
/>
<IAISelect
label={t('modelManager.modelTwo')}
validValues={modelTwoList}
onChange={(e) => setModelTwo(e.target.value)}
/>
<IAISelect
label={t('modelManager.modelThree')}
validValues={modelThreeList}
onChange={(e) => {
if (e.target.value !== 'none') {
setModelThree(e.target.value);
setModelMergeInterp('add_difference');
} else {
setModelThree('none');
setModelMergeInterp('weighted_sum');
}
}}
/>
</Flex>
<IAIInput
label={t('modelManager.mergedModelName')}
value={mergedModelName}
onChange={(e) => setMergedModelName(e.target.value)}
/>
<Flex
sx={{
flexDirection: 'column',
padding: 4,
borderRadius: 'base',
gap: 4,
bg: 'base.900',
}}
>
<IAISlider
label={t('modelManager.alpha')}
min={0.01}
max={0.99}
step={0.01}
value={modelMergeAlpha}
onChange={(v) => setModelMergeAlpha(v)}
withInput
withReset
handleReset={() => setModelMergeAlpha(0.5)}
withSliderMarks
/>
<Text variant="subtext" fontSize="sm">
{t('modelManager.modelMergeAlphaHelp')}
</Text>
</Flex>
<Flex
sx={{
padding: 4,
borderRadius: 'base',
gap: 4,
bg: 'base.900',
}}
>
<Text fontWeight={500} fontSize="sm" variant="subtext">
{t('modelManager.interpolationType')}
</Text>
<RadioGroup
value={modelMergeInterp}
onChange={(
v: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference'
) => setModelMergeInterp(v)}
>
<Flex columnGap={4}>
{modelThree === 'none' ? (
<>
<Radio value="weighted_sum">
<Text fontSize="sm">{t('modelManager.weightedSum')}</Text>
</Radio>
<Radio value="sigmoid">
<Text fontSize="sm">{t('modelManager.sigmoid')}</Text>
</Radio>
<Radio value="inv_sigmoid">
<Text fontSize="sm">{t('modelManager.inverseSigmoid')}</Text>
</Radio>
</>
) : (
<Radio value="add_difference">
<Tooltip
label={t('modelManager.modelMergeInterpAddDifferenceHelp')}
>
<Text fontSize="sm">{t('modelManager.addDifference')}</Text>
</Tooltip>
</Radio>
)}
</Flex>
</RadioGroup>
</Flex>
<Flex
sx={{
flexDirection: 'column',
padding: 4,
borderRadius: 'base',
gap: 4,
bg: 'base.900',
}}
>
<Flex columnGap={4}>
<Text fontWeight="500" fontSize="sm" variant="subtext">
{t('modelManager.mergedModelSaveLocation')}
</Text>
<RadioGroup
value={modelMergeSaveLocType}
onChange={(v: 'root' | 'custom') => setModelMergeSaveLocType(v)}
>
<Flex columnGap={4}>
<Radio value="root">
<Text fontSize="sm">{t('modelManager.invokeAIFolder')}</Text>
</Radio>
<Radio value="custom">
<Text fontSize="sm">{t('modelManager.custom')}</Text>
</Radio>
</Flex>
</RadioGroup>
</Flex>
{modelMergeSaveLocType === 'custom' && (
<IAIInput
label={t('modelManager.mergedModelCustomSaveLocation')}
value={modelMergeCustomSaveLoc}
onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)}
/>
)}
</Flex>
<IAISimpleCheckbox
label={t('modelManager.ignoreMismatch')}
isChecked={modelMergeForce}
onChange={(e) => setModelMergeForce(e.target.checked)}
fontWeight="500"
/>
<IAIButton
onClick={mergeModelsHandler}
isLoading={isProcessing}
isDisabled={
modelMergeSaveLocType === 'custom' && modelMergeCustomSaveLoc === ''
}
>
{t('modelManager.merge')}
</IAIButton>
</Flex>
);
}

View File

@ -0,0 +1,46 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { useListModelsQuery } from 'services/api/endpoints/models';
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
import ModelList from './ModelManagerPanel/ModelList';
export default function ModelManagerPanel() {
const { data: mainModels } = useListModelsQuery({
model_type: 'main',
});
const openModel = useAppSelector(
(state: RootState) => state.system.openModel
);
const renderModelEditTabs = () => {
if (!openModel || !mainModels) return;
if (mainModels['entities'][openModel]['model_format'] === 'diffusers') {
return (
<DiffusersModelEdit
modelToEdit={openModel}
retrievedModel={mainModels['entities'][openModel]}
key={openModel}
/>
);
} else {
return (
<CheckpointModelEdit
modelToEdit={openModel}
retrievedModel={mainModels['entities'][openModel]}
key={openModel}
/>
);
}
};
return (
<Flex width="100%" columnGap={8}>
<ModelList />
{renderModelEditTabs()}
</Flex>
);
}

View File

@ -0,0 +1,141 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { Divider, Flex, Text } from '@chakra-ui/react';
// import { addNewModel } from 'app/socketio/actions';
import { useForm } from '@mantine/form';
import { useTranslation } from 'react-i18next';
import type { RootState } from 'app/store/store';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
import { S } from 'services/api/types';
import ModelConvert from './ModelConvert';
const baseModelSelectData = [
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
];
const variantSelectData = [
{ value: 'normal', label: 'Normal' },
{ value: 'inpaint', label: 'Inpaint' },
{ value: 'depth', label: 'Depth' },
];
export type CheckpointModel =
| S<'StableDiffusion1ModelCheckpointConfig'>
| S<'StableDiffusion2ModelCheckpointConfig'>;
type CheckpointModelEditProps = {
modelToEdit: string;
retrievedModel: CheckpointModel;
};
export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const { modelToEdit, retrievedModel } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const checkpointEditForm = useForm({
initialValues: {
name: retrievedModel.name,
base_model: retrievedModel.base_model,
type: 'main',
path: retrievedModel.path,
description: retrievedModel.description,
model_format: 'checkpoint',
vae: retrievedModel.vae,
config: retrievedModel.config,
variant: retrievedModel.variant,
},
});
const editModelFormSubmitHandler = (values) => {
console.log(values);
};
return modelToEdit ? (
<Flex flexDirection="column" rowGap={4} width="100%">
<Flex justifyContent="space-between" alignItems="center">
<Flex flexDirection="column">
<Text fontSize="lg" fontWeight="bold">
{retrievedModel.name}
</Text>
<Text fontSize="sm" color="base.400">
{MODEL_TYPE_MAP[retrievedModel.base_model]} Model
</Text>
</Flex>
<ModelConvert model={retrievedModel} />
</Flex>
<Divider />
<Flex
flexDirection="column"
maxHeight={window.innerHeight - 270}
overflowY="scroll"
>
<form
onSubmit={checkpointEditForm.onSubmit((values) =>
editModelFormSubmitHandler(values)
)}
>
<Flex flexDirection="column" overflowY="scroll" gap={4}>
<IAIInput
label={t('modelManager.name')}
{...checkpointEditForm.getInputProps('name')}
/>
<IAIInput
label={t('modelManager.description')}
{...checkpointEditForm.getInputProps('description')}
/>
<IAIMantineSelect
label={t('modelManager.baseModel')}
data={baseModelSelectData}
{...checkpointEditForm.getInputProps('base_model')}
/>
<IAIMantineSelect
label={t('modelManager.variant')}
data={variantSelectData}
{...checkpointEditForm.getInputProps('variant')}
/>
<IAIInput
label={t('modelManager.modelLocation')}
{...checkpointEditForm.getInputProps('path')}
/>
<IAIInput
label={t('modelManager.vaeLocation')}
{...checkpointEditForm.getInputProps('vae')}
/>
<IAIInput
label={t('modelManager.config')}
{...checkpointEditForm.getInputProps('config')}
/>
<IAIButton disabled={isProcessing} type="submit">
{t('modelManager.updateModel')}
</IAIButton>
</Flex>
</form>
</Flex>
</Flex>
) : (
<Flex
sx={{
width: '100%',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
bg: 'base.900',
}}
>
<Text fontWeight={500}>Pick A Model To Edit</Text>
</Flex>
);
}

View File

@ -0,0 +1,125 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { Divider, Flex, Text } from '@chakra-ui/react';
// import { addNewModel } from 'app/socketio/actions';
import { useTranslation } from 'react-i18next';
import { useForm } from '@mantine/form';
import type { RootState } from 'app/store/store';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
import { S } from 'services/api/types';
type DiffusersModel =
| S<'StableDiffusion1ModelDiffusersConfig'>
| S<'StableDiffusion2ModelDiffusersConfig'>;
type DiffusersModelEditProps = {
modelToEdit: string;
retrievedModel: DiffusersModel;
};
const baseModelSelectData = [
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
];
const variantSelectData = [
{ value: 'normal', label: 'Normal' },
{ value: 'inpaint', label: 'Inpaint' },
{ value: 'depth', label: 'Depth' },
];
export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const { retrievedModel, modelToEdit } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const diffusersEditForm = useForm({
initialValues: {
name: retrievedModel.name,
base_model: retrievedModel.base_model,
type: 'main',
path: retrievedModel.path,
description: retrievedModel.description,
model_format: 'diffusers',
vae: retrievedModel.vae,
variant: retrievedModel.variant,
},
});
const editModelFormSubmitHandler = (values) => {
console.log(values);
};
return modelToEdit ? (
<Flex flexDirection="column" rowGap={4} width="100%">
<Flex flexDirection="column">
<Text fontSize="lg" fontWeight="bold">
{retrievedModel.name}
</Text>
<Text fontSize="sm" color="base.400">
{MODEL_TYPE_MAP[retrievedModel.base_model]} Model
</Text>
</Flex>
<Divider />
<form
onSubmit={diffusersEditForm.onSubmit((values) =>
editModelFormSubmitHandler(values)
)}
>
<Flex flexDirection="column" overflowY="scroll" gap={4}>
<IAIInput
label={t('modelManager.name')}
{...diffusersEditForm.getInputProps('name')}
/>
<IAIInput
label={t('modelManager.description')}
{...diffusersEditForm.getInputProps('description')}
/>
<IAIMantineSelect
label={t('modelManager.baseModel')}
data={baseModelSelectData}
{...diffusersEditForm.getInputProps('base_model')}
/>
<IAIMantineSelect
label={t('modelManager.variant')}
data={variantSelectData}
{...diffusersEditForm.getInputProps('variant')}
/>
<IAIInput
label={t('modelManager.modelLocation')}
{...diffusersEditForm.getInputProps('path')}
/>
<IAIInput
label={t('modelManager.vaeLocation')}
{...diffusersEditForm.getInputProps('vae')}
/>
<IAIButton disabled={isProcessing} type="submit">
{t('modelManager.updateModel')}
</IAIButton>
</Flex>
</form>
</Flex>
) : (
<Flex
sx={{
width: '100%',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
bg: 'base.900',
}}
>
<Text fontWeight={'500'}>Pick A Model To Edit</Text>
</Flex>
);
}

View File

@ -4,42 +4,28 @@ import {
Radio, Radio,
RadioGroup, RadioGroup,
Text, Text,
UnorderedList,
Tooltip, Tooltip,
UnorderedList,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
// import { convertToDiffusers } from 'app/socketio/actions'; // import { convertToDiffusers } from 'app/socketio/actions';
import { RootState } from 'app/store/store'; import { useAppDispatch } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIAlertDialog from 'common/components/IAIAlertDialog'; import IAIAlertDialog from 'common/components/IAIAlertDialog';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput'; import IAIInput from 'common/components/IAIInput';
import { useState, useEffect } from 'react'; import { useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { CheckpointModel } from './CheckpointModelEdit';
interface ModelConvertProps { interface ModelConvertProps {
model: string; model: CheckpointModel;
} }
export default function ModelConvert(props: ModelConvertProps) { export default function ModelConvert(props: ModelConvertProps) {
const { model } = props; const { model } = props;
const model_list = useAppSelector(
(state: RootState) => state.system.model_list
);
const retrievedModel = model_list[model];
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const isConnected = useAppSelector(
(state: RootState) => state.system.isConnected
);
const [saveLocation, setSaveLocation] = useState<string>('same'); const [saveLocation, setSaveLocation] = useState<string>('same');
const [customSaveLocation, setCustomSaveLocation] = useState<string>(''); const [customSaveLocation, setCustomSaveLocation] = useState<string>('');
@ -65,7 +51,7 @@ export default function ModelConvert(props: ModelConvertProps) {
return ( return (
<IAIAlertDialog <IAIAlertDialog
title={`${t('modelManager.convert')} ${model}`} title={`${t('modelManager.convert')} ${model.name}`}
acceptCallback={modelConvertHandler} acceptCallback={modelConvertHandler}
cancelCallback={modelConvertCancelHandler} cancelCallback={modelConvertCancelHandler}
acceptButtonText={`${t('modelManager.convert')}`} acceptButtonText={`${t('modelManager.convert')}`}
@ -73,11 +59,7 @@ export default function ModelConvert(props: ModelConvertProps) {
<IAIButton <IAIButton
size={'sm'} size={'sm'}
aria-label={t('modelManager.convertToDiffusers')} aria-label={t('modelManager.convertToDiffusers')}
isDisabled={
retrievedModel.status === 'active' || isProcessing || !isConnected
}
className=" modal-close-btn" className=" modal-close-btn"
marginInlineEnd={8}
> >
🧨 {t('modelManager.convertToDiffusers')} 🧨 {t('modelManager.convertToDiffusers')}
</IAIButton> </IAIButton>

View File

@ -1,36 +1,14 @@
import { Box, Flex, Heading, Spacer, Spinner, Text } from '@chakra-ui/react'; import { Box, Flex, Spinner, Text } from '@chakra-ui/react';
import IAIInput from 'common/components/IAIInput';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import AddModel from './AddModel';
import ModelListItem from './ModelListItem'; import ModelListItem from './ModelListItem';
import MergeModels from './MergeModels';
import { useAppSelector } from 'app/store/storeHooks';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { createSelector } from '@reduxjs/toolkit';
import { systemSelector } from 'features/system/store/systemSelectors';
import type { SystemState } from 'features/system/store/systemSlice';
import { isEqual, map } from 'lodash-es';
import React, { useMemo, useState, useTransition } from 'react';
import type { ChangeEvent, ReactNode } from 'react'; import type { ChangeEvent, ReactNode } from 'react';
import React, { useMemo, useState, useTransition } from 'react';
const modelListSelector = createSelector( import { useListModelsQuery } from 'services/api/endpoints/models';
systemSelector,
(system: SystemState) => {
const models = map(system.model_list, (model, key) => {
return { name: key, ...model };
});
return models;
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
function ModelFilterButton({ function ModelFilterButton({
label, label,
@ -58,7 +36,9 @@ function ModelFilterButton({
} }
const ModelList = () => { const ModelList = () => {
const models = useAppSelector(modelListSelector); const { data: mainModels } = useListModelsQuery({
model_type: 'main',
});
const [renderModelList, setRenderModelList] = React.useState<boolean>(false); const [renderModelList, setRenderModelList] = React.useState<boolean>(false);
@ -90,43 +70,49 @@ const ModelList = () => {
const filteredModelListItemsToRender: ReactNode[] = []; const filteredModelListItemsToRender: ReactNode[] = [];
const localFilteredModelListItemsToRender: ReactNode[] = []; const localFilteredModelListItemsToRender: ReactNode[] = [];
models.forEach((model, i) => { if (!mainModels) return;
if (model.name.toLowerCase().includes(searchText.toLowerCase())) {
const modelList = mainModels.entities;
Object.keys(modelList).forEach((model, i) => {
if (
modelList[model].name.toLowerCase().includes(searchText.toLowerCase())
) {
filteredModelListItemsToRender.push( filteredModelListItemsToRender.push(
<ModelListItem <ModelListItem
key={i} key={i}
name={model.name} modelKey={model}
status={model.status} name={modelList[model].name}
description={model.description} description={modelList[model].description}
/> />
); );
if (model.format === isSelectedFilter) { if (modelList[model]?.model_format === isSelectedFilter) {
localFilteredModelListItemsToRender.push( localFilteredModelListItemsToRender.push(
<ModelListItem <ModelListItem
key={i} key={i}
name={model.name} modelKey={model}
status={model.status} name={modelList[model].name}
description={model.description} description={modelList[model].description}
/> />
); );
} }
} }
if (model.format !== 'diffusers') { if (modelList[model]?.model_format !== 'diffusers') {
ckptModelListItemsToRender.push( ckptModelListItemsToRender.push(
<ModelListItem <ModelListItem
key={i} key={i}
name={model.name} modelKey={model}
status={model.status} name={modelList[model].name}
description={model.description} description={modelList[model].description}
/> />
); );
} else { } else {
diffusersModelListItemsToRender.push( diffusersModelListItemsToRender.push(
<ModelListItem <ModelListItem
key={i} key={i}
name={model.name} modelKey={model}
status={model.status} name={modelList[model].name}
description={model.description} description={modelList[model].description}
/> />
); );
} }
@ -142,6 +128,23 @@ const ModelList = () => {
<Flex flexDirection="column" rowGap={6}> <Flex flexDirection="column" rowGap={6}>
{isSelectedFilter === 'all' && ( {isSelectedFilter === 'all' && (
<> <>
<Box>
<Text
sx={{
fontWeight: '500',
py: 2,
px: 4,
mb: 4,
borderRadius: 'base',
width: 'max-content',
fontSize: 'sm',
bg: 'base.750',
}}
>
{t('modelManager.diffusersModels')}
</Text>
{diffusersModelListItemsToRender}
</Box>
<Box> <Box>
<Text <Text
sx={{ sx={{
@ -160,50 +163,26 @@ const ModelList = () => {
</Text> </Text>
{ckptModelListItemsToRender} {ckptModelListItemsToRender}
</Box> </Box>
<Box>
<Text
sx={{
fontWeight: '500',
py: 2,
px: 4,
mb: 4,
borderRadius: 'base',
width: 'max-content',
fontSize: 'sm',
bg: 'base.750',
}}
>
{t('modelManager.diffusersModels')}
</Text>
{diffusersModelListItemsToRender}
</Box>
</> </>
)} )}
{isSelectedFilter === 'ckpt' && (
<Flex flexDirection="column" marginTop={4}>
{ckptModelListItemsToRender}
</Flex>
)}
{isSelectedFilter === 'diffusers' && ( {isSelectedFilter === 'diffusers' && (
<Flex flexDirection="column" marginTop={4}> <Flex flexDirection="column" marginTop={4}>
{diffusersModelListItemsToRender} {diffusersModelListItemsToRender}
</Flex> </Flex>
)} )}
{isSelectedFilter === 'ckpt' && (
<Flex flexDirection="column" marginTop={4}>
{ckptModelListItemsToRender}
</Flex>
)}
</Flex> </Flex>
); );
}, [models, searchText, t, isSelectedFilter]); }, [mainModels, searchText, t, isSelectedFilter]);
return ( return (
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%"> <Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
<Flex justifyContent="space-between" alignItems="center" gap={2}>
<Heading size="md">{t('modelManager.availableModels')}</Heading>
<Spacer />
<AddModel />
<MergeModels />
</Flex>
<IAIInput <IAIInput
onChange={handleSearchFilter} onChange={handleSearchFilter}
label={t('modelManager.search')} label={t('modelManager.search')}
@ -211,7 +190,7 @@ const ModelList = () => {
<Flex <Flex
flexDirection="column" flexDirection="column"
gap={1} gap={4}
maxHeight={window.innerHeight - 240} maxHeight={window.innerHeight - 240}
overflow="scroll" overflow="scroll"
paddingInlineEnd={4} paddingInlineEnd={4}
@ -222,16 +201,16 @@ const ModelList = () => {
onClick={() => setIsSelectedFilter('all')} onClick={() => setIsSelectedFilter('all')}
isActive={isSelectedFilter === 'all'} isActive={isSelectedFilter === 'all'}
/> />
<ModelFilterButton
label={t('modelManager.checkpointModels')}
onClick={() => setIsSelectedFilter('ckpt')}
isActive={isSelectedFilter === 'ckpt'}
/>
<ModelFilterButton <ModelFilterButton
label={t('modelManager.diffusersModels')} label={t('modelManager.diffusersModels')}
onClick={() => setIsSelectedFilter('diffusers')} onClick={() => setIsSelectedFilter('diffusers')}
isActive={isSelectedFilter === 'diffusers'} isActive={isSelectedFilter === 'diffusers'}
/> />
<ModelFilterButton
label={t('modelManager.checkpointModels')}
onClick={() => setIsSelectedFilter('ckpt')}
isActive={isSelectedFilter === 'ckpt'}
/>
</Flex> </Flex>
{renderModelList ? ( {renderModelList ? (

View File

@ -1,6 +1,6 @@
import { DeleteIcon, EditIcon } from '@chakra-ui/icons'; import { DeleteIcon, EditIcon } from '@chakra-ui/icons';
import { Box, Button, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react'; import { Box, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react';
import { ModelStatus } from 'app/types/invokeai';
// import { deleteModel, requestModelChange } from 'app/socketio/actions'; // import { deleteModel, requestModelChange } from 'app/socketio/actions';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@ -10,9 +10,9 @@ import { setOpenModel } from 'features/system/store/systemSlice';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
type ModelListItemProps = { type ModelListItemProps = {
modelKey: string;
name: string; name: string;
status: ModelStatus; description: string | undefined;
description: string;
}; };
export default function ModelListItem(props: ModelListItemProps) { export default function ModelListItem(props: ModelListItemProps) {
@ -28,39 +28,24 @@ export default function ModelListItem(props: ModelListItemProps) {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { name, status, description } = props; const { modelKey, name, description } = props;
const handleChangeModel = () => {
dispatch(requestModelChange(name));
};
const openModelHandler = () => { const openModelHandler = () => {
dispatch(setOpenModel(name)); dispatch(setOpenModel(modelKey));
}; };
const handleModelDelete = () => { const handleModelDelete = () => {
dispatch(deleteModel(name)); dispatch(deleteModel(modelKey));
dispatch(setOpenModel(null)); dispatch(setOpenModel(null));
}; };
const statusTextColor = () => {
switch (status) {
case 'active':
return 'ok.500';
case 'cached':
return 'warning.500';
case 'not loaded':
return 'inherit';
}
};
return ( return (
<Flex <Flex
alignItems="center" alignItems="center"
p={2} p={2}
borderRadius="base" borderRadius="base"
sx={ sx={
name === openModel modelKey === openModel
? { ? {
bg: 'accent.750', bg: 'accent.750',
_hover: { _hover: {
@ -81,15 +66,6 @@ export default function ModelListItem(props: ModelListItemProps) {
</Box> </Box>
<Spacer onClick={openModelHandler} cursor="pointer" /> <Spacer onClick={openModelHandler} cursor="pointer" />
<Flex gap={2} alignItems="center"> <Flex gap={2} alignItems="center">
<Text color={statusTextColor()}>{status}</Text>
<Button
size="sm"
onClick={handleChangeModel}
isDisabled={status === 'active' || isProcessing || !isConnected}
>
{t('modelManager.load')}
</Button>
<IAIIconButton <IAIIconButton
icon={<EditIcon />} icon={<EditIcon />}
size="sm" size="sm"

View File

@ -1,17 +1,18 @@
import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations';
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth';
import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight';
import { Box, Flex, useDisclosure } from '@chakra-ui/react'; import { Box, Flex, useDisclosure } from '@chakra-ui/react';
import { useAppSelector } from 'app/store/storeHooks';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { uiSelector } from 'features/ui/store/uiSelectors'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { memo } from 'react';
import ParamSchedulerAndModel from 'features/parameters/components/Parameters/Core/ParamSchedulerAndModel';
import IAICollapse from 'common/components/IAICollapse'; import IAICollapse from 'common/components/IAICollapse';
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight';
import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations';
import ParamModelandVAE from 'features/parameters/components/Parameters/Core/ParamModelandVAE';
import ParamScheduler from 'features/parameters/components/Parameters/Core/ParamScheduler';
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth';
import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull'; import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo } from 'react';
const selector = createSelector( const selector = createSelector(
uiSelector, uiSelector,
@ -37,7 +38,7 @@ const TextToImageTabCoreParameters = () => {
> >
{shouldUseSliders ? ( {shouldUseSliders ? (
<> <>
<ParamSchedulerAndModel /> <ParamModelandVAE />
<Box pt={2}> <Box pt={2}>
<ParamSeedFull /> <ParamSeedFull />
</Box> </Box>
@ -54,7 +55,8 @@ const TextToImageTabCoreParameters = () => {
<ParamSteps /> <ParamSteps />
<ParamCFGScale /> <ParamCFGScale />
</Flex> </Flex>
<ParamSchedulerAndModel /> <ParamModelandVAE />
<ParamScheduler />
<Box pt={2}> <Box pt={2}>
<ParamSeedFull /> <ParamSeedFull />
</Box> </Box>

View File

@ -1,18 +1,19 @@
import { memo } from 'react';
import { Box, Flex, useDisclosure } from '@chakra-ui/react'; import { Box, Flex, useDisclosure } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations';
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength';
import ParamSchedulerAndModel from 'features/parameters/components/Parameters/Core/ParamSchedulerAndModel';
import ParamBoundingBoxWidth from 'features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxWidth';
import ParamBoundingBoxHeight from 'features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxHeight';
import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull';
import IAICollapse from 'common/components/IAICollapse'; import IAICollapse from 'common/components/IAICollapse';
import ParamBoundingBoxHeight from 'features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxHeight';
import ParamBoundingBoxWidth from 'features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxWidth';
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations';
import ParamModelandVAE from 'features/parameters/components/Parameters/Core/ParamModelandVAE';
import ParamScheduler from 'features/parameters/components/Parameters/Core/ParamScheduler';
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength';
import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo } from 'react';
const selector = createSelector( const selector = createSelector(
uiSelector, uiSelector,
@ -38,7 +39,7 @@ const UnifiedCanvasCoreParameters = () => {
> >
{shouldUseSliders ? ( {shouldUseSliders ? (
<> <>
<ParamSchedulerAndModel /> <ParamModelandVAE />
<Box pt={2}> <Box pt={2}>
<ParamSeedFull /> <ParamSeedFull />
</Box> </Box>
@ -55,7 +56,8 @@ const UnifiedCanvasCoreParameters = () => {
<ParamSteps /> <ParamSteps />
<ParamCFGScale /> <ParamCFGScale />
</Flex> </Flex>
<ParamSchedulerAndModel /> <ParamModelandVAE />
<ParamScheduler />
<Box pt={2}> <Box pt={2}>
<ParamSeedFull /> <ParamSeedFull />
</Box> </Box>

View File

@ -7,6 +7,7 @@ export const tabMap = [
'batch', 'batch',
// 'postprocessing', // 'postprocessing',
// 'training', // 'training',
'modelManager',
] as const; ] as const;
export type InvokeTabName = (typeof tabMap)[number]; export type InvokeTabName = (typeof tabMap)[number];

View File

@ -76,9 +76,16 @@ export type paths = {
*/ */
get: operations["list_models"]; get: operations["list_models"];
/** /**
* Import Model * Update Model
* @description Add Model * @description Add Model
*/ */
post: operations["update_model"];
};
"/api/v1/models/import": {
/**
* Import Model
* @description Add a model using its local path, repo_id, or remote URL
*/
post: operations["import_model"]; post: operations["import_model"];
}; };
"/api/v1/models/{model_name}": { "/api/v1/models/{model_name}": {
@ -227,6 +234,23 @@ export type components = {
*/ */
b?: number; b?: number;
}; };
/** AddModelResult */
AddModelResult: {
/**
* Name
* @description The name of the model after import
*/
name: string;
/** @description The type of model */
model_type: components["schemas"]["ModelType"];
/** @description The base model */
base_model: components["schemas"]["BaseModelType"];
/**
* Config
* @description The configuration of the model
*/
config: components["schemas"]["ModelConfigBase"];
};
/** /**
* BaseModelType * BaseModelType
* @description An enumeration. * @description An enumeration.
@ -1030,7 +1054,7 @@ export type components = {
* @description The nodes in this graph * @description The nodes in this graph
*/ */
nodes?: { nodes?: {
[key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined; [key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
}; };
/** /**
* Edges * Edges
@ -1073,7 +1097,7 @@ export type components = {
* @description The results of node executions * @description The results of node executions
*/ */
results: { results: {
[key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined; [key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
}; };
/** /**
* Errors * Errors
@ -1975,19 +1999,23 @@ export type components = {
*/ */
thumbnail_url: string; thumbnail_url: string;
}; };
/** ImportModelRequest */ /** ImportModelResponse */
ImportModelRequest: { ImportModelResponse: {
/** /**
* Name * Name
* @description A model path, repo_id or URL to import * @description The name of the imported model
*/ */
name: string; name: string;
/** /**
* Prediction Type * Info
* @description Prediction type for SDv2 checkpoint files * @description The model info
* @enum {string}
*/ */
prediction_type?: "epsilon" | "v_prediction" | "sample"; info: components["schemas"]["AddModelResult"];
/**
* Status
* @description The status of the API response
*/
status: string;
}; };
/** /**
* InfillColorInvocation * InfillColorInvocation
@ -2781,6 +2809,47 @@ export type components = {
*/ */
clip?: components["schemas"]["ClipField"]; clip?: components["schemas"]["ClipField"];
}; };
/**
* MainModelField
* @description Main model field
*/
MainModelField: {
/**
* Model Name
* @description Name of the model
*/
model_name: string;
/** @description Base model */
base_model: components["schemas"]["BaseModelType"];
};
/**
* MainModelLoaderInvocation
* @description Loads a main model, outputting its submodels.
*/
MainModelLoaderInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default main_model_loader
* @enum {string}
*/
type?: "main_model_loader";
/**
* Model
* @description The model to load
*/
model: components["schemas"]["MainModelField"];
};
/** /**
* MaskFromAlphaInvocation * MaskFromAlphaInvocation
* @description Extracts the alpha channel of an image as a mask. * @description Extracts the alpha channel of an image as a mask.
@ -2974,6 +3043,16 @@ export type components = {
*/ */
thr_d?: number; thr_d?: number;
}; };
/** ModelConfigBase */
ModelConfigBase: {
/** Path */
path: string;
/** Description */
description?: string;
/** Model Format */
model_format?: string;
error?: components["schemas"]["ModelError"];
};
/** /**
* ModelError * ModelError
* @description An enumeration. * @description An enumeration.
@ -3036,7 +3115,7 @@ export type components = {
/** ModelsList */ /** ModelsList */
ModelsList: { ModelsList: {
/** Models */ /** Models */
models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[]; models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[];
}; };
/** /**
* MultiplyInvocation * MultiplyInvocation
@ -3425,47 +3504,6 @@ export type components = {
*/ */
scribble?: boolean; scribble?: boolean;
}; };
/**
* PipelineModelField
* @description Pipeline model field
*/
PipelineModelField: {
/**
* Model Name
* @description Name of the model
*/
model_name: string;
/** @description Base model */
base_model: components["schemas"]["BaseModelType"];
};
/**
* PipelineModelLoaderInvocation
* @description Loads a pipeline model, outputting its submodels.
*/
PipelineModelLoaderInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default pipeline_model_loader
* @enum {string}
*/
type?: "pipeline_model_loader";
/**
* Model
* @description The model to load
*/
model: components["schemas"]["PipelineModelField"];
};
/** /**
* PromptCollectionOutput * PromptCollectionOutput
* @description Base class for invocations that output a collection of prompts * @description Base class for invocations that output a collection of prompts
@ -4266,6 +4304,19 @@ export type components = {
*/ */
level?: 2 | 4; level?: 2 | 4;
}; };
/**
* VAEModelField
* @description Vae model field
*/
VAEModelField: {
/**
* Model Name
* @description Name of the model
*/
model_name: string;
/** @description Base model */
base_model: components["schemas"]["BaseModelType"];
};
/** VaeField */ /** VaeField */
VaeField: { VaeField: {
/** /**
@ -4274,6 +4325,51 @@ export type components = {
*/ */
vae: components["schemas"]["ModelInfo"]; vae: components["schemas"]["ModelInfo"];
}; };
/**
* VaeLoaderInvocation
* @description Loads a VAE model, outputting a VaeLoaderOutput
*/
VaeLoaderInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default vae_loader
* @enum {string}
*/
type?: "vae_loader";
/**
* Vae Model
* @description The VAE to load
*/
vae_model: components["schemas"]["VAEModelField"];
};
/**
* VaeLoaderOutput
* @description Model loader output
*/
VaeLoaderOutput: {
/**
* Type
* @default vae_loader_output
* @enum {string}
*/
type?: "vae_loader_output";
/**
* Vae
* @description Vae model
*/
vae?: components["schemas"]["VaeField"];
};
/** VaeModelConfig */ /** VaeModelConfig */
VaeModelConfig: { VaeModelConfig: {
/** Name */ /** Name */
@ -4474,7 +4570,7 @@ export type operations = {
}; };
requestBody: { requestBody: {
content: { content: {
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]; "application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
}; };
}; };
responses: { responses: {
@ -4511,7 +4607,7 @@ export type operations = {
}; };
requestBody: { requestBody: {
content: { content: {
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]; "application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
}; };
}; };
responses: { responses: {
@ -4731,13 +4827,13 @@ export type operations = {
}; };
}; };
/** /**
* Import Model * Update Model
* @description Add Model * @description Add Model
*/ */
import_model: { update_model: {
requestBody: { requestBody: {
content: { content: {
"application/json": components["schemas"]["ImportModelRequest"]; "application/json": components["schemas"]["CreateModelRequest"];
}; };
}; };
responses: { responses: {
@ -4755,6 +4851,36 @@ export type operations = {
}; };
}; };
}; };
/**
* Import Model
* @description Add a model using its local path, repo_id, or remote URL
*/
import_model: {
parameters: {
query: {
/** @description A model path, repo_id or URL to import */
name: string;
/** @description Prediction type for SDv2 checkpoint files */
prediction_type?: "v_prediction" | "epsilon" | "sample";
};
};
responses: {
/** @description The model imported successfully */
201: {
content: {
"application/json": components["schemas"]["ImportModelResponse"];
};
};
/** @description The model could not be found */
404: never;
/** @description Validation Error */
422: {
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
/** /**
* Delete Model * Delete Model
* @description Delete Model * @description Delete Model

View File

@ -33,7 +33,8 @@ export type OffsetPaginatedResults_ImageDTO_ =
// Models // Models
export type ModelType = S<'ModelType'>; export type ModelType = S<'ModelType'>;
export type BaseModelType = S<'BaseModelType'>; export type BaseModelType = S<'BaseModelType'>;
export type PipelineModelField = S<'PipelineModelField'>; export type MainModelField = S<'MainModelField'>;
export type VAEModelField = S<'VAEModelField'>;
export type ModelsList = S<'ModelsList'>; export type ModelsList = S<'ModelsList'>;
// Graphs // Graphs
@ -57,8 +58,8 @@ export type TextToLatentsInvocation = N<'TextToLatentsInvocation'>;
export type LatentsToLatentsInvocation = N<'LatentsToLatentsInvocation'>; export type LatentsToLatentsInvocation = N<'LatentsToLatentsInvocation'>;
export type ImageToLatentsInvocation = N<'ImageToLatentsInvocation'>; export type ImageToLatentsInvocation = N<'ImageToLatentsInvocation'>;
export type LatentsToImageInvocation = N<'LatentsToImageInvocation'>; export type LatentsToImageInvocation = N<'LatentsToImageInvocation'>;
export type PipelineModelLoaderInvocation = N<'PipelineModelLoaderInvocation'>;
export type ImageCollectionInvocation = N<'ImageCollectionInvocation'>; export type ImageCollectionInvocation = N<'ImageCollectionInvocation'>;
export type MainModelLoaderInvocation = N<'MainModelLoaderInvocation'>;
// ControlNet Nodes // ControlNet Nodes
export type ControlNetInvocation = N<'ControlNetInvocation'>; export type ControlNetInvocation = N<'ControlNetInvocation'>;

File diff suppressed because one or more lines are too long

View File

@ -1328,6 +1328,14 @@
react-remove-scroll "^2.5.5" react-remove-scroll "^2.5.5"
react-textarea-autosize "8.3.4" react-textarea-autosize "8.3.4"
"@mantine/form@^6.0.15":
version "6.0.15"
resolved "https://registry.yarnpkg.com/@mantine/form/-/form-6.0.15.tgz#e78d953669888e01d3778ee8f62d469a12668c42"
integrity sha512-Tz4AuZZ/ddGvEh5zJbDyi9PlGqTilJBdCjRGIgs3zn3hQsfg+ku7/NUR5zNB64dcWPJvGKc074y4iopNIl3FWQ==
dependencies:
fast-deep-equal "^3.1.3"
klona "^2.0.5"
"@mantine/hooks@^6.0.14": "@mantine/hooks@^6.0.14":
version "6.0.14" version "6.0.14"
resolved "https://registry.yarnpkg.com/@mantine/hooks/-/hooks-6.0.14.tgz#5f52a75cdd36b14c13a5ffeeedc510d73db76dc0" resolved "https://registry.yarnpkg.com/@mantine/hooks/-/hooks-6.0.14.tgz#5f52a75cdd36b14c13a5ffeeedc510d73db76dc0"
@ -4454,6 +4462,11 @@ klaw-sync@^6.0.0:
dependencies: dependencies:
graceful-fs "^4.1.11" graceful-fs "^4.1.11"
klona@^2.0.5:
version "2.0.6"
resolved "https://registry.yarnpkg.com/klona/-/klona-2.0.6.tgz#85bffbf819c03b2f53270412420a4555ef882e22"
integrity sha512-dhG34DXATL5hSxJbIexCft8FChFXtmskoZYnoPWjXQuebWYCNkVeV3KkGegCK9CP1oswI/vQibS2GY7Em/sJJA==
kolorist@^1.7.0: kolorist@^1.7.0:
version "1.8.0" version "1.8.0"
resolved "https://registry.yarnpkg.com/kolorist/-/kolorist-1.8.0.tgz#edddbbbc7894bc13302cdf740af6374d4a04743c" resolved "https://registry.yarnpkg.com/kolorist/-/kolorist-1.8.0.tgz#edddbbbc7894bc13302cdf740af6374d4a04743c"