mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
commit
92b163e95c
@ -2,17 +2,17 @@
|
||||
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from fastapi import Query
|
||||
from fastapi import Query, Body
|
||||
from fastapi.routing import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field, parse_obj_as
|
||||
from ..dependencies import ApiDependencies
|
||||
from invokeai.backend import BaseModelType, ModelType
|
||||
from invokeai.backend.model_management import AddModelResult
|
||||
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
|
||||
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
|
||||
class VaeRepo(BaseModel):
|
||||
repo_id: str = Field(description="The repo ID to use for this VAE")
|
||||
path: Optional[str] = Field(description="The path to the VAE")
|
||||
@ -51,9 +51,12 @@ class CreateModelResponse(BaseModel):
|
||||
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
|
||||
status: str = Field(description="The status of the API response")
|
||||
|
||||
class ImportModelRequest(BaseModel):
|
||||
name: str = Field(description="A model path, repo_id or URL to import")
|
||||
prediction_type: Optional[Literal['epsilon','v_prediction','sample']] = Field(description='Prediction type for SDv2 checkpoint files')
|
||||
class ImportModelResponse(BaseModel):
|
||||
name: str = Field(description="The name of the imported model")
|
||||
# base_model: str = Field(description="The base model")
|
||||
# model_type: str = Field(description="The model type")
|
||||
info: AddModelResult = Field(description="The model info")
|
||||
status: str = Field(description="The status of the API response")
|
||||
|
||||
class ConversionRequest(BaseModel):
|
||||
name: str = Field(description="The name of the new model")
|
||||
@ -86,7 +89,6 @@ async def list_models(
|
||||
models = parse_obj_as(ModelsList, { "models": models_raw })
|
||||
return models
|
||||
|
||||
|
||||
@models_router.post(
|
||||
"/",
|
||||
operation_id="update_model",
|
||||
@ -109,27 +111,38 @@ async def update_model(
|
||||
return model_response
|
||||
|
||||
@models_router.post(
|
||||
"/",
|
||||
"/import",
|
||||
operation_id="import_model",
|
||||
responses={200: {"status": "success"}},
|
||||
responses= {
|
||||
201: {"description" : "The model imported successfully"},
|
||||
404: {"description" : "The model could not be found"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImportModelResponse
|
||||
)
|
||||
async def import_model(
|
||||
model_request: ImportModelRequest
|
||||
) -> None:
|
||||
""" Add Model """
|
||||
items_to_import = set([model_request.name])
|
||||
name: str = Query(description="A model path, repo_id or URL to import"),
|
||||
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = Query(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
||||
) -> ImportModelResponse:
|
||||
""" Add a model using its local path, repo_id, or remote URL """
|
||||
items_to_import = {name}
|
||||
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||
items_to_import = items_to_import,
|
||||
prediction_type_helper = lambda x: prediction_types.get(model_request.prediction_type)
|
||||
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
|
||||
)
|
||||
if len(installed_models) > 0:
|
||||
logger.info(f'Successfully imported {model_request.name}')
|
||||
if info := installed_models.get(name):
|
||||
logger.info(f'Successfully imported {name}, got {info}')
|
||||
return ImportModelResponse(
|
||||
name = name,
|
||||
info = info,
|
||||
status = "success",
|
||||
)
|
||||
else:
|
||||
logger.error(f'Model {model_request.name} not imported')
|
||||
raise HTTPException(status_code=500, detail=f'Model {model_request.name} not imported')
|
||||
logger.error(f'Model {name} not imported')
|
||||
raise HTTPException(status_code=404, detail=f'Model {name} not found')
|
||||
|
||||
@models_router.delete(
|
||||
"/{model_name}",
|
||||
|
@ -1,11 +1,12 @@
|
||||
from typing import Literal, Optional, Union, List
|
||||
from pydantic import BaseModel, Field
|
||||
import copy
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
model_name: str = Field(description="Info to load submodel")
|
||||
@ -30,7 +31,6 @@ class VaeField(BaseModel):
|
||||
# TODO: better naming?
|
||||
vae: ModelInfo = Field(description="Info to load vae submodel")
|
||||
|
||||
|
||||
class ModelLoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
|
||||
@ -43,25 +43,26 @@ class ModelLoaderOutput(BaseInvocationOutput):
|
||||
#fmt: on
|
||||
|
||||
|
||||
class PipelineModelField(BaseModel):
|
||||
"""Pipeline model field"""
|
||||
class MainModelField(BaseModel):
|
||||
"""Main model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
|
||||
class PipelineModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a pipeline model, outputting its submodels."""
|
||||
class MainModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
||||
type: Literal["pipeline_model_loader"] = "pipeline_model_loader"
|
||||
type: Literal["main_model_loader"] = "main_model_loader"
|
||||
|
||||
model: PipelineModelField = Field(description="The model to load")
|
||||
model: MainModelField = Field(description="The model to load")
|
||||
# TODO: precision?
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Model Loader",
|
||||
"tags": ["model", "loader"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
@ -175,6 +176,14 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
|
||||
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Lora Loader",
|
||||
"tags": ["lora", "loader"],
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
||||
|
||||
# TODO: ui rewrite
|
||||
@ -221,3 +230,56 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
|
||||
return output
|
||||
|
||||
class VAEModelField(BaseModel):
|
||||
"""Vae model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
class VaeLoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
|
||||
#fmt: off
|
||||
type: Literal["vae_loader_output"] = "vae_loader_output"
|
||||
|
||||
vae: VaeField = Field(default=None, description="Vae model")
|
||||
#fmt: on
|
||||
|
||||
class VaeLoaderInvocation(BaseInvocation):
|
||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||
type: Literal["vae_loader"] = "vae_loader"
|
||||
|
||||
vae_model: VAEModelField = Field(description="The VAE to load")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "VAE Loader",
|
||||
"tags": ["vae", "loader"],
|
||||
"type_hints": {
|
||||
"vae_model": "vae_model"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
|
||||
base_model = self.vae_model.base_model
|
||||
model_name = self.vae_model.model_name
|
||||
model_type = ModelType.Vae
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
base_model=base_model,
|
||||
model_name=model_name,
|
||||
model_type=model_type,
|
||||
):
|
||||
raise Exception(f"Unkown vae name: {model_name}!")
|
||||
return VaeLoaderOutput(
|
||||
vae=VaeField(
|
||||
vae = ModelInfo(
|
||||
model_name = model_name,
|
||||
base_model = base_model,
|
||||
model_type = model_type,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
@ -135,6 +135,29 @@ class ModelManagerServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def heuristic_import(self,
|
||||
items_to_import: Set[str],
|
||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||
)->Dict[str, AddModelResult]:
|
||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||
|
||||
The prediction type helper is necessary to distinguish between
|
||||
models based on Stable Diffusion 2 Base (requiring
|
||||
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
|
||||
(requiring SchedulerPredictionType.VPrediction). It is
|
||||
generally impossible to do this programmatically, so the
|
||||
prediction_type_helper usually asks the user to choose.
|
||||
|
||||
The result is a set of successfully installed models. Each element
|
||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||
that model.
|
||||
'''
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def commit(self, conf_file: Path = None) -> None:
|
||||
"""
|
||||
@ -361,3 +384,24 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
def logger(self):
|
||||
return self.mgr.logger
|
||||
|
||||
def heuristic_import(self,
|
||||
items_to_import: Set[str],
|
||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||
)->Dict[str, AddModelResult]:
|
||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||
|
||||
The prediction type helper is necessary to distinguish between
|
||||
models based on Stable Diffusion 2 Base (requiring
|
||||
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
|
||||
(requiring SchedulerPredictionType.VPrediction). It is
|
||||
generally impossible to do this programmatically, so the
|
||||
prediction_type_helper usually asks the user to choose.
|
||||
|
||||
The result is a set of successfully installed models. Each element
|
||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||
that model.
|
||||
'''
|
||||
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
||||
|
@ -18,7 +18,7 @@ from tqdm import tqdm
|
||||
import invokeai.configs as configs
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType
|
||||
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
|
||||
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
|
||||
from invokeai.backend.util import download_with_resume
|
||||
from ..util.logging import InvokeAILogger
|
||||
@ -166,17 +166,22 @@ class ModelInstall(object):
|
||||
# add requested models
|
||||
for path in selections.install_models:
|
||||
logger.info(f'Installing {path} [{job}/{jobs}]')
|
||||
self.heuristic_install(path)
|
||||
self.heuristic_import(path)
|
||||
job += 1
|
||||
|
||||
self.mgr.commit()
|
||||
|
||||
def heuristic_install(self,
|
||||
def heuristic_import(self,
|
||||
model_path_id_or_url: Union[str,Path],
|
||||
models_installed: Set[Path]=None)->Set[Path]:
|
||||
models_installed: Set[Path]=None)->Dict[str, AddModelResult]:
|
||||
'''
|
||||
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
|
||||
:param models_installed: Set of installed models, used for recursive invocation
|
||||
Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
|
||||
'''
|
||||
|
||||
if not models_installed:
|
||||
models_installed = set()
|
||||
models_installed = dict()
|
||||
|
||||
# A little hack to allow nested routines to retrieve info on the requested ID
|
||||
self.current_id = model_path_id_or_url
|
||||
@ -185,24 +190,24 @@ class ModelInstall(object):
|
||||
try:
|
||||
# checkpoint file, or similar
|
||||
if path.is_file():
|
||||
models_installed.add(self._install_path(path))
|
||||
models_installed.update(self._install_path(path))
|
||||
|
||||
# folders style or similar
|
||||
elif path.is_dir() and any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
|
||||
models_installed.add(self._install_path(path))
|
||||
models_installed.update(self._install_path(path))
|
||||
|
||||
# recursive scan
|
||||
elif path.is_dir():
|
||||
for child in path.iterdir():
|
||||
self.heuristic_install(child, models_installed=models_installed)
|
||||
self.heuristic_import(child, models_installed=models_installed)
|
||||
|
||||
# huggingface repo
|
||||
elif len(str(path).split('/')) == 2:
|
||||
models_installed.add(self._install_repo(str(path)))
|
||||
models_installed.update(self._install_repo(str(path)))
|
||||
|
||||
# a URL
|
||||
elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")):
|
||||
models_installed.add(self._install_url(model_path_id_or_url))
|
||||
models_installed.update(self._install_url(model_path_id_or_url))
|
||||
|
||||
else:
|
||||
logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
|
||||
@ -214,24 +219,25 @@ class ModelInstall(object):
|
||||
|
||||
# install a model from a local path. The optional info parameter is there to prevent
|
||||
# the model from being probed twice in the event that it has already been probed.
|
||||
def _install_path(self, path: Path, info: ModelProbeInfo=None)->Path:
|
||||
def _install_path(self, path: Path, info: ModelProbeInfo=None)->Dict[str, AddModelResult]:
|
||||
try:
|
||||
# logger.debug(f'Probing {path}')
|
||||
model_result = None
|
||||
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
|
||||
model_name = path.stem if info.format=='checkpoint' else path.name
|
||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||
raise ValueError(f'A model named "{model_name}" is already installed.')
|
||||
attributes = self._make_attributes(path,info)
|
||||
self.mgr.add_model(model_name = model_name,
|
||||
base_model = info.base_type,
|
||||
model_type = info.model_type,
|
||||
model_attributes = attributes,
|
||||
)
|
||||
model_result = self.mgr.add_model(model_name = model_name,
|
||||
base_model = info.base_type,
|
||||
model_type = info.model_type,
|
||||
model_attributes = attributes,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f'{str(e)} Skipping registration.')
|
||||
return path
|
||||
return {}
|
||||
return {str(path): model_result}
|
||||
|
||||
def _install_url(self, url: str)->Path:
|
||||
def _install_url(self, url: str)->dict:
|
||||
# copy to a staging area, probe, import and delete
|
||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||
location = download_with_resume(url,Path(staging))
|
||||
@ -244,7 +250,7 @@ class ModelInstall(object):
|
||||
# staged version will be garbage-collected at this time
|
||||
return self._install_path(Path(models_path), info)
|
||||
|
||||
def _install_repo(self, repo_id: str)->Path:
|
||||
def _install_repo(self, repo_id: str)->dict:
|
||||
hinfo = HfApi().model_info(repo_id)
|
||||
|
||||
# we try to figure out how to download this most economically
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend.model_management
|
||||
"""
|
||||
from .model_manager import ModelManager, ModelInfo
|
||||
from .model_manager import ModelManager, ModelInfo, AddModelResult
|
||||
from .model_cache import ModelCache
|
||||
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
|
||||
|
||||
|
@ -233,14 +233,14 @@ import hashlib
|
||||
import textwrap
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Tuple, Union, Set, Callable, types
|
||||
from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types
|
||||
from shutil import rmtree
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
@ -278,8 +278,13 @@ class InvalidModelError(Exception):
|
||||
"Raised when an invalid model is requested"
|
||||
pass
|
||||
|
||||
MAX_CACHE_SIZE = 6.0 # GB
|
||||
class AddModelResult(BaseModel):
|
||||
name: str = Field(description="The name of the model after import")
|
||||
model_type: ModelType = Field(description="The type of model")
|
||||
base_model: BaseModelType = Field(description="The base model")
|
||||
config: ModelConfigBase = Field(description="The configuration of the model")
|
||||
|
||||
MAX_CACHE_SIZE = 6.0 # GB
|
||||
|
||||
class ConfigMeta(BaseModel):
|
||||
version: str
|
||||
@ -571,13 +576,16 @@ class ModelManager(object):
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False,
|
||||
) -> None:
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
On a successful update, the config will be changed in memory and the
|
||||
method will return True. Will fail with an assertion error if provided
|
||||
attributes are incorrect or the model name is missing.
|
||||
|
||||
The returned dict has the same format as the dict returned by
|
||||
model_info().
|
||||
"""
|
||||
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
@ -601,12 +609,18 @@ class ModelManager(object):
|
||||
old_model_cache.unlink()
|
||||
|
||||
# remove in-memory cache
|
||||
# note: it not garantie to release memory(model can has other references)
|
||||
# note: it not guaranteed to release memory(model can has other references)
|
||||
cache_ids = self.cache_keys.pop(model_key, [])
|
||||
for cache_id in cache_ids:
|
||||
self.cache.uncache_model(cache_id)
|
||||
|
||||
self.models[model_key] = model_config
|
||||
return AddModelResult(
|
||||
name = model_name,
|
||||
model_type = model_type,
|
||||
base_model = base_model,
|
||||
config = model_config,
|
||||
)
|
||||
|
||||
def search_models(self, search_folder):
|
||||
self.logger.info(f"Finding Models In: {search_folder}")
|
||||
@ -729,7 +743,7 @@ class ModelManager(object):
|
||||
if (new_models_found or imported_models) and self.config_path:
|
||||
self.commit()
|
||||
|
||||
def autoimport(self)->set[Path]:
|
||||
def autoimport(self)->Dict[str, AddModelResult]:
|
||||
'''
|
||||
Scan the autoimport directory (if defined) and import new models, delete defunct models.
|
||||
'''
|
||||
@ -742,7 +756,6 @@ class ModelManager(object):
|
||||
prediction_type_helper = ask_user_for_prediction_type,
|
||||
)
|
||||
|
||||
installed = set()
|
||||
scanned_dirs = set()
|
||||
|
||||
config = self.app_config
|
||||
@ -756,13 +769,14 @@ class ModelManager(object):
|
||||
continue
|
||||
|
||||
self.logger.info(f'Scanning {autodir} for models to import')
|
||||
installed = dict()
|
||||
|
||||
autodir = self.app_config.root_path / autodir
|
||||
if not autodir.exists():
|
||||
continue
|
||||
|
||||
items_scanned = 0
|
||||
new_models_found = set()
|
||||
new_models_found = dict()
|
||||
|
||||
for root, dirs, files in os.walk(autodir):
|
||||
items_scanned += len(dirs) + len(files)
|
||||
@ -772,7 +786,7 @@ class ModelManager(object):
|
||||
scanned_dirs.add(path)
|
||||
continue
|
||||
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
|
||||
new_models_found.update(installer.heuristic_install(path))
|
||||
new_models_found.update(installer.heuristic_import(path))
|
||||
scanned_dirs.add(path)
|
||||
|
||||
for f in files:
|
||||
@ -780,7 +794,7 @@ class ModelManager(object):
|
||||
if path in known_paths or path.parent in scanned_dirs:
|
||||
continue
|
||||
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
|
||||
new_models_found.update(installer.heuristic_install(path))
|
||||
new_models_found.update(installer.heuristic_import(path))
|
||||
|
||||
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
|
||||
installed.update(new_models_found)
|
||||
@ -790,7 +804,7 @@ class ModelManager(object):
|
||||
def heuristic_import(self,
|
||||
items_to_import: Set[str],
|
||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||
)->Set[str]:
|
||||
)->Dict[str, AddModelResult]:
|
||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
@ -803,17 +817,20 @@ class ModelManager(object):
|
||||
generally impossible to do this programmatically, so the
|
||||
prediction_type_helper usually asks the user to choose.
|
||||
|
||||
The result is a set of successfully installed models. Each element
|
||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||
that model.
|
||||
'''
|
||||
# avoid circular import here
|
||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||
successfully_installed = set()
|
||||
successfully_installed = dict()
|
||||
|
||||
installer = ModelInstall(config = self.app_config,
|
||||
prediction_type_helper = prediction_type_helper,
|
||||
model_manager = self)
|
||||
for thing in items_to_import:
|
||||
try:
|
||||
installed = installer.heuristic_install(thing)
|
||||
installed = installer.heuristic_import(thing)
|
||||
successfully_installed.update(installed)
|
||||
except Exception as e:
|
||||
self.logger.warning(f'{thing} could not be imported: {str(e)}')
|
||||
|
2
invokeai/frontend/web/dist/index.html
vendored
2
invokeai/frontend/web/dist/index.html
vendored
@ -12,7 +12,7 @@
|
||||
margin: 0;
|
||||
}
|
||||
</style>
|
||||
<script type="module" crossorigin src="./assets/index-8a3e9251.js"></script>
|
||||
<script type="module" crossorigin src="./assets/index-c0367e37.js"></script>
|
||||
</head>
|
||||
|
||||
<body dir="ltr">
|
||||
|
17
invokeai/frontend/web/dist/locales/en.json
vendored
17
invokeai/frontend/web/dist/locales/en.json
vendored
@ -24,16 +24,13 @@
|
||||
},
|
||||
"common": {
|
||||
"hotkeysLabel": "Hotkeys",
|
||||
"themeLabel": "Theme",
|
||||
"darkMode": "Dark Mode",
|
||||
"lightMode": "Light Mode",
|
||||
"languagePickerLabel": "Language",
|
||||
"reportBugLabel": "Report Bug",
|
||||
"githubLabel": "Github",
|
||||
"discordLabel": "Discord",
|
||||
"settingsLabel": "Settings",
|
||||
"darkTheme": "Dark",
|
||||
"lightTheme": "Light",
|
||||
"greenTheme": "Green",
|
||||
"oceanTheme": "Ocean",
|
||||
"langArabic": "العربية",
|
||||
"langEnglish": "English",
|
||||
"langDutch": "Nederlands",
|
||||
@ -55,6 +52,7 @@
|
||||
"unifiedCanvas": "Unified Canvas",
|
||||
"linear": "Linear",
|
||||
"nodes": "Node Editor",
|
||||
"modelmanager": "Model Manager",
|
||||
"postprocessing": "Post Processing",
|
||||
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
|
||||
"postProcessing": "Post Processing",
|
||||
@ -336,6 +334,7 @@
|
||||
"modelManager": {
|
||||
"modelManager": "Model Manager",
|
||||
"model": "Model",
|
||||
"vae": "VAE",
|
||||
"allModels": "All Models",
|
||||
"checkpointModels": "Checkpoints",
|
||||
"diffusersModels": "Diffusers",
|
||||
@ -351,6 +350,7 @@
|
||||
"scanForModels": "Scan For Models",
|
||||
"addManually": "Add Manually",
|
||||
"manual": "Manual",
|
||||
"baseModel": "Base Model",
|
||||
"name": "Name",
|
||||
"nameValidationMsg": "Enter a name for your model",
|
||||
"description": "Description",
|
||||
@ -363,6 +363,7 @@
|
||||
"repoIDValidationMsg": "Online repository of your model",
|
||||
"vaeLocation": "VAE Location",
|
||||
"vaeLocationValidationMsg": "Path to where your VAE is located.",
|
||||
"variant": "Variant",
|
||||
"vaeRepoID": "VAE Repo ID",
|
||||
"vaeRepoIDValidationMsg": "Online repository of your VAE",
|
||||
"width": "Width",
|
||||
@ -524,7 +525,8 @@
|
||||
"initialImage": "Initial Image",
|
||||
"showOptionsPanel": "Show Options Panel",
|
||||
"hidePreview": "Hide Preview",
|
||||
"showPreview": "Show Preview"
|
||||
"showPreview": "Show Preview",
|
||||
"controlNetControlMode": "Control Mode"
|
||||
},
|
||||
"settings": {
|
||||
"models": "Models",
|
||||
@ -547,7 +549,8 @@
|
||||
"general": "General",
|
||||
"generation": "Generation",
|
||||
"ui": "User Interface",
|
||||
"availableSchedulers": "Available Schedulers"
|
||||
"favoriteSchedulers": "Favorite Schedulers",
|
||||
"favoriteSchedulersPlaceholder": "No schedulers favorited"
|
||||
},
|
||||
"toast": {
|
||||
"serverError": "Server Error",
|
||||
|
@ -67,6 +67,7 @@
|
||||
"@fontsource-variable/inter": "^5.0.3",
|
||||
"@fontsource/inter": "^5.0.3",
|
||||
"@mantine/core": "^6.0.14",
|
||||
"@mantine/form": "^6.0.15",
|
||||
"@mantine/hooks": "^6.0.14",
|
||||
"@reduxjs/toolkit": "^1.9.5",
|
||||
"@roarr/browser-log-writer": "^1.1.5",
|
||||
|
@ -53,6 +53,7 @@
|
||||
"linear": "Linear",
|
||||
"nodes": "Node Editor",
|
||||
"batch": "Batch Manager",
|
||||
"modelmanager": "Model Manager",
|
||||
"postprocessing": "Post Processing",
|
||||
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
|
||||
"postProcessing": "Post Processing",
|
||||
@ -334,6 +335,7 @@
|
||||
"modelManager": {
|
||||
"modelManager": "Model Manager",
|
||||
"model": "Model",
|
||||
"vae": "VAE",
|
||||
"allModels": "All Models",
|
||||
"checkpointModels": "Checkpoints",
|
||||
"diffusersModels": "Diffusers",
|
||||
@ -349,6 +351,7 @@
|
||||
"scanForModels": "Scan For Models",
|
||||
"addManually": "Add Manually",
|
||||
"manual": "Manual",
|
||||
"baseModel": "Base Model",
|
||||
"name": "Name",
|
||||
"nameValidationMsg": "Enter a name for your model",
|
||||
"description": "Description",
|
||||
@ -361,6 +364,7 @@
|
||||
"repoIDValidationMsg": "Online repository of your model",
|
||||
"vaeLocation": "VAE Location",
|
||||
"vaeLocationValidationMsg": "Path to where your VAE is located.",
|
||||
"variant": "Variant",
|
||||
"vaeRepoID": "VAE Repo ID",
|
||||
"vaeRepoIDValidationMsg": "Online repository of your VAE",
|
||||
"width": "Width",
|
||||
|
@ -4,6 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { PartialAppConfig } from 'app/types/invokeai';
|
||||
import ImageUploader from 'common/components/ImageUploader';
|
||||
import GalleryDrawer from 'features/gallery/components/GalleryPanel';
|
||||
import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal';
|
||||
import Lightbox from 'features/lightbox/components/Lightbox';
|
||||
import SiteHeader from 'features/system/components/SiteHeader';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
@ -15,11 +16,10 @@ import InvokeTabs from 'features/ui/components/InvokeTabs';
|
||||
import ParametersDrawer from 'features/ui/components/ParametersDrawer';
|
||||
import i18n from 'i18n';
|
||||
import { ReactNode, memo, useEffect } from 'react';
|
||||
import DeleteBoardImagesModal from '../../features/gallery/components/Boards/DeleteBoardImagesModal';
|
||||
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
|
||||
import GlobalHotkeys from './GlobalHotkeys';
|
||||
import Toaster from './Toaster';
|
||||
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
|
||||
import DeleteBoardImagesModal from '../../features/gallery/components/Boards/DeleteBoardImagesModal';
|
||||
import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal';
|
||||
|
||||
const DEFAULT_CONFIG = {};
|
||||
|
||||
|
@ -3,20 +3,21 @@ import { memo } from 'react';
|
||||
import { InputFieldTemplate, InputFieldValue } from '../types/types';
|
||||
import ArrayInputFieldComponent from './fields/ArrayInputFieldComponent';
|
||||
import BooleanInputFieldComponent from './fields/BooleanInputFieldComponent';
|
||||
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
|
||||
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
|
||||
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
|
||||
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
|
||||
import UNetInputFieldComponent from './fields/UNetInputFieldComponent';
|
||||
import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
|
||||
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
|
||||
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
|
||||
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
|
||||
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
|
||||
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
|
||||
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
|
||||
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
|
||||
import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
|
||||
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
|
||||
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
|
||||
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
|
||||
import StringInputFieldComponent from './fields/StringInputFieldComponent';
|
||||
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
|
||||
import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
|
||||
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
|
||||
import UNetInputFieldComponent from './fields/UNetInputFieldComponent';
|
||||
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
|
||||
import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent';
|
||||
|
||||
type InputFieldComponentProps = {
|
||||
nodeId: string;
|
||||
@ -152,6 +153,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'vae_model' && template.type === 'vae_model') {
|
||||
return (
|
||||
<VaeModelInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'array' && template.type === 'array') {
|
||||
return (
|
||||
<ArrayInputFieldComponent
|
||||
|
@ -6,13 +6,13 @@ import {
|
||||
ModelInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
|
||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { forEach, isString } from 'lodash-es';
|
||||
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
||||
import { forEach, isString } from 'lodash-es';
|
||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useListModelsQuery } from 'services/api/endpoints/models';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
const ModelInputFieldComponent = (
|
||||
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
|
||||
@ -22,18 +22,18 @@ const ModelInputFieldComponent = (
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { data: pipelineModels } = useListModelsQuery({
|
||||
const { data: mainModels } = useListModelsQuery({
|
||||
model_type: 'main',
|
||||
});
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!pipelineModels) {
|
||||
if (!mainModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(pipelineModels.entities, (model, id) => {
|
||||
forEach(mainModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
@ -46,11 +46,11 @@ const ModelInputFieldComponent = (
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [pipelineModels]);
|
||||
}, [mainModels]);
|
||||
|
||||
const selectedModel = useMemo(
|
||||
() => pipelineModels?.entities[field.value ?? pipelineModels.ids[0]],
|
||||
[pipelineModels?.entities, pipelineModels?.ids, field.value]
|
||||
() => mainModels?.entities[field.value ?? mainModels.ids[0]],
|
||||
[mainModels?.entities, mainModels?.ids, field.value]
|
||||
);
|
||||
|
||||
const handleValueChanged = useCallback(
|
||||
@ -71,18 +71,18 @@ const ModelInputFieldComponent = (
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (field.value && pipelineModels?.ids.includes(field.value)) {
|
||||
if (field.value && mainModels?.ids.includes(field.value)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const firstModel = pipelineModels?.ids[0];
|
||||
const firstModel = mainModels?.ids[0];
|
||||
|
||||
if (!isString(firstModel)) {
|
||||
return;
|
||||
}
|
||||
|
||||
handleValueChanged(firstModel);
|
||||
}, [field.value, handleValueChanged, pipelineModels?.ids]);
|
||||
}, [field.value, handleValueChanged, mainModels?.ids]);
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
|
@ -0,0 +1,97 @@
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
VaeModelInputFieldTemplate,
|
||||
VaeModelInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useListModelsQuery } from 'services/api/endpoints/models';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
const VaeModelInputFieldComponent = (
|
||||
props: FieldComponentProps<
|
||||
VaeModelInputFieldValue,
|
||||
VaeModelInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { data: vaeModels } = useListModelsQuery({
|
||||
model_type: 'vae',
|
||||
});
|
||||
|
||||
const selectedModel = useMemo(
|
||||
() => vaeModels?.entities[field.value ?? vaeModels.ids[0]],
|
||||
[vaeModels?.entities, vaeModels?.ids, field.value]
|
||||
);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!vaeModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(vaeModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.name,
|
||||
group: BASE_MODEL_NAME_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [vaeModels]);
|
||||
|
||||
const handleValueChanged = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
fieldValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: v,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (field.value && vaeModels?.ids.includes(field.value)) {
|
||||
return;
|
||||
}
|
||||
handleValueChanged('auto');
|
||||
}, [field.value, handleValueChanged, vaeModels?.ids]);
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
tooltip={selectedModel?.description}
|
||||
label={
|
||||
selectedModel?.base_model &&
|
||||
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
|
||||
}
|
||||
value={field.value}
|
||||
placeholder="Pick one"
|
||||
data={data}
|
||||
onChange={handleValueChanged}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(VaeModelInputFieldComponent);
|
@ -17,6 +17,7 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
|
||||
ClipField: 'clip',
|
||||
VaeField: 'vae',
|
||||
model: 'model',
|
||||
vae_model: 'vae_model',
|
||||
array: 'array',
|
||||
item: 'item',
|
||||
ColorField: 'color',
|
||||
@ -116,6 +117,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
title: 'Model',
|
||||
description: 'Models are models.',
|
||||
},
|
||||
vae_model: {
|
||||
color: 'teal',
|
||||
colorCssVar: getColorTokenCssVariable('teal'),
|
||||
title: 'Model',
|
||||
description: 'Models are models.',
|
||||
},
|
||||
array: {
|
||||
color: 'gray',
|
||||
colorCssVar: getColorTokenCssVariable('gray'),
|
||||
|
@ -64,6 +64,7 @@ export type FieldType =
|
||||
| 'vae'
|
||||
| 'control'
|
||||
| 'model'
|
||||
| 'vae_model'
|
||||
| 'array'
|
||||
| 'item'
|
||||
| 'color'
|
||||
@ -91,6 +92,7 @@ export type InputFieldValue =
|
||||
| ControlInputFieldValue
|
||||
| EnumInputFieldValue
|
||||
| ModelInputFieldValue
|
||||
| VaeModelInputFieldValue
|
||||
| ArrayInputFieldValue
|
||||
| ItemInputFieldValue
|
||||
| ColorInputFieldValue
|
||||
@ -116,6 +118,7 @@ export type InputFieldTemplate =
|
||||
| ControlInputFieldTemplate
|
||||
| EnumInputFieldTemplate
|
||||
| ModelInputFieldTemplate
|
||||
| VaeModelInputFieldTemplate
|
||||
| ArrayInputFieldTemplate
|
||||
| ItemInputFieldTemplate
|
||||
| ColorInputFieldTemplate
|
||||
@ -228,6 +231,11 @@ export type ModelInputFieldValue = FieldValueBase & {
|
||||
value?: string;
|
||||
};
|
||||
|
||||
export type VaeModelInputFieldValue = FieldValueBase & {
|
||||
type: 'vae_model';
|
||||
value?: string;
|
||||
};
|
||||
|
||||
export type ArrayInputFieldValue = FieldValueBase & {
|
||||
type: 'array';
|
||||
value?: (string | number)[];
|
||||
@ -305,6 +313,21 @@ export type ConditioningInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'conditioning';
|
||||
};
|
||||
|
||||
export type UNetInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'unet';
|
||||
};
|
||||
|
||||
export type ClipInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'clip';
|
||||
};
|
||||
|
||||
export type VaeInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'vae';
|
||||
};
|
||||
|
||||
export type ControlInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'control';
|
||||
@ -322,6 +345,11 @@ export type ModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'model';
|
||||
};
|
||||
|
||||
export type VaeModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: string;
|
||||
type: 'vae_model';
|
||||
};
|
||||
|
||||
export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: [];
|
||||
type: 'array';
|
||||
|
@ -3,27 +3,28 @@ import { OpenAPIV3 } from 'openapi-types';
|
||||
import { FIELD_TYPE_MAP } from '../types/constants';
|
||||
import { isSchemaObject } from '../types/typeGuards';
|
||||
import {
|
||||
BooleanInputFieldTemplate,
|
||||
EnumInputFieldTemplate,
|
||||
FloatInputFieldTemplate,
|
||||
ImageInputFieldTemplate,
|
||||
IntegerInputFieldTemplate,
|
||||
LatentsInputFieldTemplate,
|
||||
ConditioningInputFieldTemplate,
|
||||
UNetInputFieldTemplate,
|
||||
ClipInputFieldTemplate,
|
||||
VaeInputFieldTemplate,
|
||||
ControlInputFieldTemplate,
|
||||
StringInputFieldTemplate,
|
||||
ModelInputFieldTemplate,
|
||||
ArrayInputFieldTemplate,
|
||||
ItemInputFieldTemplate,
|
||||
BooleanInputFieldTemplate,
|
||||
ClipInputFieldTemplate,
|
||||
ColorInputFieldTemplate,
|
||||
InputFieldTemplateBase,
|
||||
OutputFieldTemplate,
|
||||
TypeHints,
|
||||
ConditioningInputFieldTemplate,
|
||||
ControlInputFieldTemplate,
|
||||
EnumInputFieldTemplate,
|
||||
FieldType,
|
||||
FloatInputFieldTemplate,
|
||||
ImageCollectionInputFieldTemplate,
|
||||
ImageInputFieldTemplate,
|
||||
InputFieldTemplateBase,
|
||||
IntegerInputFieldTemplate,
|
||||
ItemInputFieldTemplate,
|
||||
LatentsInputFieldTemplate,
|
||||
ModelInputFieldTemplate,
|
||||
OutputFieldTemplate,
|
||||
StringInputFieldTemplate,
|
||||
TypeHints,
|
||||
UNetInputFieldTemplate,
|
||||
VaeInputFieldTemplate,
|
||||
VaeModelInputFieldTemplate,
|
||||
} from '../types/types';
|
||||
|
||||
export type BaseFieldProperties = 'name' | 'title' | 'description';
|
||||
@ -175,6 +176,21 @@ const buildModelInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildVaeModelInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): VaeModelInputFieldTemplate => {
|
||||
const template: VaeModelInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'vae_model',
|
||||
inputRequirement: 'always',
|
||||
inputKind: 'direct',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildImageInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -441,6 +457,9 @@ export const buildInputFieldTemplate = (
|
||||
if (['model'].includes(fieldType)) {
|
||||
return buildModelInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['vae_model'].includes(fieldType)) {
|
||||
return buildVaeModelInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['enum'].includes(fieldType)) {
|
||||
return buildEnumInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
|
@ -75,6 +75,10 @@ export const buildInputFieldValue = (
|
||||
if (template.type === 'model') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'vae_model') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
return fieldValue;
|
||||
|
@ -0,0 +1,68 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
|
||||
import {
|
||||
IMAGE_TO_IMAGE_GRAPH,
|
||||
IMAGE_TO_LATENTS,
|
||||
INPAINT,
|
||||
INPAINT_GRAPH,
|
||||
LATENTS_TO_IMAGE,
|
||||
MAIN_MODEL_LOADER,
|
||||
TEXT_TO_IMAGE_GRAPH,
|
||||
VAE_LOADER,
|
||||
} from './constants';
|
||||
|
||||
export const addVAEToGraph = (
|
||||
graph: NonNullableGraph,
|
||||
state: RootState
|
||||
): void => {
|
||||
const { vae: vaeId } = state.generation;
|
||||
const vae_model = modelIdToVAEModelField(vaeId);
|
||||
|
||||
if (vaeId !== 'auto') {
|
||||
graph.nodes[VAE_LOADER] = {
|
||||
type: 'vae_loader',
|
||||
id: VAE_LOADER,
|
||||
vae_model,
|
||||
};
|
||||
}
|
||||
|
||||
if (graph.id === TEXT_TO_IMAGE_GRAPH || graph.id === IMAGE_TO_IMAGE_GRAPH) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'vae',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
if (graph.id === IMAGE_TO_IMAGE_GRAPH) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: IMAGE_TO_LATENTS,
|
||||
field: 'vae',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
if (graph.id === INPAINT_GRAPH) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT,
|
||||
field: 'vae',
|
||||
},
|
||||
});
|
||||
}
|
||||
};
|
@ -1,31 +1,26 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import {
|
||||
ImageDTO,
|
||||
ImageResizeInvocation,
|
||||
ImageToLatentsInvocation,
|
||||
RandomIntInvocation,
|
||||
RangeOfSizeInvocation,
|
||||
} from 'services/api/types';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import {
|
||||
ITERATE,
|
||||
IMAGE_TO_IMAGE_GRAPH,
|
||||
IMAGE_TO_LATENTS,
|
||||
LATENTS_TO_IMAGE,
|
||||
PIPELINE_MODEL_LOADER,
|
||||
LATENTS_TO_LATENTS,
|
||||
MAIN_MODEL_LOADER,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
RANDOM_INT,
|
||||
RANGE_OF_SIZE,
|
||||
IMAGE_TO_IMAGE_GRAPH,
|
||||
IMAGE_TO_LATENTS,
|
||||
LATENTS_TO_LATENTS,
|
||||
RESIZE,
|
||||
} from './constants';
|
||||
import { set } from 'lodash-es';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'nodes' });
|
||||
|
||||
@ -52,7 +47,7 @@ export const buildCanvasImageToImageGraph = (
|
||||
// The bounding box determines width and height, not the width and height params
|
||||
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||
|
||||
const model = modelIdToPipelineModelField(modelId);
|
||||
const model = modelIdToMainModelField(modelId);
|
||||
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
@ -81,9 +76,9 @@ export const buildCanvasImageToImageGraph = (
|
||||
type: 'noise',
|
||||
id: NOISE,
|
||||
},
|
||||
[PIPELINE_MODEL_LOADER]: {
|
||||
type: 'pipeline_model_loader',
|
||||
id: PIPELINE_MODEL_LOADER,
|
||||
[MAIN_MODEL_LOADER]: {
|
||||
type: 'main_model_loader',
|
||||
id: MAIN_MODEL_LOADER,
|
||||
model,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
@ -110,7 +105,7 @@ export const buildCanvasImageToImageGraph = (
|
||||
edges: [
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
@ -120,7 +115,7 @@ export const buildCanvasImageToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
@ -128,16 +123,6 @@ export const buildCanvasImageToImageGraph = (
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: LATENTS_TO_LATENTS,
|
||||
@ -170,17 +155,7 @@ export const buildCanvasImageToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: IMAGE_TO_LATENTS,
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
@ -277,6 +252,9 @@ export const buildCanvasImageToImageGraph = (
|
||||
});
|
||||
}
|
||||
|
||||
// Add VAE
|
||||
addVAEToGraph(graph, state);
|
||||
|
||||
// add dynamic prompts, mutating `graph`
|
||||
addDynamicPromptsToGraph(graph, state);
|
||||
|
||||
|
@ -1,23 +1,24 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import {
|
||||
ImageDTO,
|
||||
InpaintInvocation,
|
||||
RandomIntInvocation,
|
||||
RangeOfSizeInvocation,
|
||||
} from 'services/api/types';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import {
|
||||
INPAINT,
|
||||
INPAINT_GRAPH,
|
||||
ITERATE,
|
||||
PIPELINE_MODEL_LOADER,
|
||||
MAIN_MODEL_LOADER,
|
||||
NEGATIVE_CONDITIONING,
|
||||
POSITIVE_CONDITIONING,
|
||||
RANDOM_INT,
|
||||
RANGE_OF_SIZE,
|
||||
INPAINT_GRAPH,
|
||||
INPAINT,
|
||||
} from './constants';
|
||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'nodes' });
|
||||
|
||||
@ -55,7 +56,7 @@ export const buildCanvasInpaintGraph = (
|
||||
// We may need to set the inpaint width and height to scale the image
|
||||
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
|
||||
|
||||
const model = modelIdToPipelineModelField(modelId);
|
||||
const model = modelIdToMainModelField(modelId);
|
||||
|
||||
const graph: NonNullableGraph = {
|
||||
id: INPAINT_GRAPH,
|
||||
@ -101,9 +102,9 @@ export const buildCanvasInpaintGraph = (
|
||||
id: NEGATIVE_CONDITIONING,
|
||||
prompt: negativePrompt,
|
||||
},
|
||||
[PIPELINE_MODEL_LOADER]: {
|
||||
type: 'pipeline_model_loader',
|
||||
id: PIPELINE_MODEL_LOADER,
|
||||
[MAIN_MODEL_LOADER]: {
|
||||
type: 'main_model_loader',
|
||||
id: MAIN_MODEL_LOADER,
|
||||
model,
|
||||
},
|
||||
[RANGE_OF_SIZE]: {
|
||||
@ -142,7 +143,7 @@ export const buildCanvasInpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
@ -152,7 +153,7 @@ export const buildCanvasInpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
@ -162,7 +163,7 @@ export const buildCanvasInpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
@ -170,16 +171,6 @@ export const buildCanvasInpaintGraph = (
|
||||
field: 'unet',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT,
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: RANGE_OF_SIZE,
|
||||
@ -203,6 +194,9 @@ export const buildCanvasInpaintGraph = (
|
||||
],
|
||||
};
|
||||
|
||||
// Add VAE
|
||||
addVAEToGraph(graph, state);
|
||||
|
||||
// handle seed
|
||||
if (shouldRandomizeSeed) {
|
||||
// Random int node to generate the starting seed
|
||||
|
@ -1,21 +1,18 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api/types';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import {
|
||||
ITERATE,
|
||||
LATENTS_TO_IMAGE,
|
||||
PIPELINE_MODEL_LOADER,
|
||||
MAIN_MODEL_LOADER,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
RANDOM_INT,
|
||||
RANGE_OF_SIZE,
|
||||
TEXT_TO_IMAGE_GRAPH,
|
||||
TEXT_TO_LATENTS,
|
||||
} from './constants';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
|
||||
/**
|
||||
* Builds the Canvas tab's Text to Image graph.
|
||||
@ -38,7 +35,7 @@ export const buildCanvasTextToImageGraph = (
|
||||
// The bounding box determines width and height, not the width and height params
|
||||
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||
|
||||
const model = modelIdToPipelineModelField(modelId);
|
||||
const model = modelIdToMainModelField(modelId);
|
||||
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
@ -76,9 +73,9 @@ export const buildCanvasTextToImageGraph = (
|
||||
scheduler,
|
||||
steps,
|
||||
},
|
||||
[PIPELINE_MODEL_LOADER]: {
|
||||
type: 'pipeline_model_loader',
|
||||
id: PIPELINE_MODEL_LOADER,
|
||||
[MAIN_MODEL_LOADER]: {
|
||||
type: 'main_model_loader',
|
||||
id: MAIN_MODEL_LOADER,
|
||||
model,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
@ -109,7 +106,7 @@ export const buildCanvasTextToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
@ -119,7 +116,7 @@ export const buildCanvasTextToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
@ -129,7 +126,7 @@ export const buildCanvasTextToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
@ -147,16 +144,6 @@ export const buildCanvasTextToImageGraph = (
|
||||
field: 'latents',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: NOISE,
|
||||
@ -170,6 +157,9 @@ export const buildCanvasTextToImageGraph = (
|
||||
],
|
||||
};
|
||||
|
||||
// Add VAE
|
||||
addVAEToGraph(graph, state);
|
||||
|
||||
// add dynamic prompts, mutating `graph`
|
||||
addDynamicPromptsToGraph(graph, state);
|
||||
|
||||
|
@ -1,28 +1,29 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import {
|
||||
ImageCollectionInvocation,
|
||||
ImageResizeInvocation,
|
||||
ImageToLatentsInvocation,
|
||||
IterateInvocation,
|
||||
} from 'services/api/types';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import {
|
||||
IMAGE_COLLECTION,
|
||||
IMAGE_COLLECTION_ITERATE,
|
||||
IMAGE_TO_IMAGE_GRAPH,
|
||||
IMAGE_TO_LATENTS,
|
||||
LATENTS_TO_IMAGE,
|
||||
PIPELINE_MODEL_LOADER,
|
||||
LATENTS_TO_LATENTS,
|
||||
MAIN_MODEL_LOADER,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
IMAGE_TO_IMAGE_GRAPH,
|
||||
IMAGE_TO_LATENTS,
|
||||
LATENTS_TO_LATENTS,
|
||||
RESIZE,
|
||||
IMAGE_COLLECTION,
|
||||
IMAGE_COLLECTION_ITERATE,
|
||||
} from './constants';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'nodes' });
|
||||
|
||||
@ -69,7 +70,7 @@ export const buildLinearImageToImageGraph = (
|
||||
throw new Error('No initial image found in state');
|
||||
}
|
||||
|
||||
const model = modelIdToPipelineModelField(modelId);
|
||||
const model = modelIdToMainModelField(modelId);
|
||||
|
||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||
const graph: NonNullableGraph = {
|
||||
@ -89,9 +90,9 @@ export const buildLinearImageToImageGraph = (
|
||||
type: 'noise',
|
||||
id: NOISE,
|
||||
},
|
||||
[PIPELINE_MODEL_LOADER]: {
|
||||
type: 'pipeline_model_loader',
|
||||
id: PIPELINE_MODEL_LOADER,
|
||||
[MAIN_MODEL_LOADER]: {
|
||||
type: 'main_model_loader',
|
||||
id: MAIN_MODEL_LOADER,
|
||||
model,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
@ -118,7 +119,7 @@ export const buildLinearImageToImageGraph = (
|
||||
edges: [
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
@ -128,7 +129,7 @@ export const buildLinearImageToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
@ -136,16 +137,6 @@ export const buildLinearImageToImageGraph = (
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: LATENTS_TO_LATENTS,
|
||||
@ -176,19 +167,10 @@ export const buildLinearImageToImageGraph = (
|
||||
field: 'noise',
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: IMAGE_TO_LATENTS,
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
@ -322,6 +304,8 @@ export const buildLinearImageToImageGraph = (
|
||||
},
|
||||
});
|
||||
}
|
||||
// Add VAE
|
||||
addVAEToGraph(graph, state);
|
||||
|
||||
// add dynamic prompts, mutating `graph`
|
||||
addDynamicPromptsToGraph(graph, state);
|
||||
|
@ -1,17 +1,18 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import {
|
||||
LATENTS_TO_IMAGE,
|
||||
PIPELINE_MODEL_LOADER,
|
||||
MAIN_MODEL_LOADER,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
TEXT_TO_IMAGE_GRAPH,
|
||||
TEXT_TO_LATENTS,
|
||||
} from './constants';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
|
||||
export const buildLinearTextToImageGraph = (
|
||||
state: RootState
|
||||
@ -27,7 +28,7 @@ export const buildLinearTextToImageGraph = (
|
||||
height,
|
||||
} = state.generation;
|
||||
|
||||
const model = modelIdToPipelineModelField(modelId);
|
||||
const model = modelIdToMainModelField(modelId);
|
||||
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
@ -65,9 +66,9 @@ export const buildLinearTextToImageGraph = (
|
||||
scheduler,
|
||||
steps,
|
||||
},
|
||||
[PIPELINE_MODEL_LOADER]: {
|
||||
type: 'pipeline_model_loader',
|
||||
id: PIPELINE_MODEL_LOADER,
|
||||
[MAIN_MODEL_LOADER]: {
|
||||
type: 'main_model_loader',
|
||||
id: MAIN_MODEL_LOADER,
|
||||
model,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
@ -98,7 +99,7 @@ export const buildLinearTextToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
@ -108,7 +109,7 @@ export const buildLinearTextToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
@ -118,7 +119,7 @@ export const buildLinearTextToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
@ -136,16 +137,6 @@ export const buildLinearTextToImageGraph = (
|
||||
field: 'latents',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: NOISE,
|
||||
@ -159,6 +150,9 @@ export const buildLinearTextToImageGraph = (
|
||||
],
|
||||
};
|
||||
|
||||
// Add Custom VAE Support
|
||||
addVAEToGraph(graph, state);
|
||||
|
||||
// add dynamic prompts, mutating `graph`
|
||||
addDynamicPromptsToGraph(graph, state);
|
||||
|
||||
|
@ -1,10 +1,11 @@
|
||||
import { Graph } from 'services/api/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { cloneDeep, omit, reduce } from 'lodash-es';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { InputFieldValue } from 'features/nodes/types/types';
|
||||
import { cloneDeep, omit, reduce } from 'lodash-es';
|
||||
import { Graph } from 'services/api/types';
|
||||
import { AnyInvocation } from 'services/events/types';
|
||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||
import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
|
||||
|
||||
/**
|
||||
* We need to do special handling for some fields
|
||||
@ -27,7 +28,13 @@ export const parseFieldValue = (field: InputFieldValue) => {
|
||||
|
||||
if (field.type === 'model') {
|
||||
if (field.value) {
|
||||
return modelIdToPipelineModelField(field.value);
|
||||
return modelIdToMainModelField(field.value);
|
||||
}
|
||||
}
|
||||
|
||||
if (field.type === 'vae_model') {
|
||||
if (field.value) {
|
||||
return modelIdToVAEModelField(field.value);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -7,7 +7,8 @@ export const NOISE = 'noise';
|
||||
export const RANDOM_INT = 'rand_int';
|
||||
export const RANGE_OF_SIZE = 'range_of_size';
|
||||
export const ITERATE = 'iterate';
|
||||
export const PIPELINE_MODEL_LOADER = 'pipeline_model_loader';
|
||||
export const MAIN_MODEL_LOADER = 'main_model_loader';
|
||||
export const VAE_LOADER = 'vae_loader';
|
||||
export const IMAGE_TO_LATENTS = 'image_to_latents';
|
||||
export const LATENTS_TO_LATENTS = 'latents_to_latents';
|
||||
export const RESIZE = 'resize_image';
|
||||
|
@ -0,0 +1,16 @@
|
||||
import { BaseModelType, MainModelField } from 'services/api/types';
|
||||
|
||||
/**
|
||||
* Crudely converts a model id to a main model field
|
||||
* TODO: Make better
|
||||
*/
|
||||
export const modelIdToMainModelField = (modelId: string): MainModelField => {
|
||||
const [base_model, model_type, model_name] = modelId.split('/');
|
||||
|
||||
const field: MainModelField = {
|
||||
base_model: base_model as BaseModelType,
|
||||
model_name,
|
||||
};
|
||||
|
||||
return field;
|
||||
};
|
@ -1,18 +0,0 @@
|
||||
import { BaseModelType, PipelineModelField } from 'services/api/types';
|
||||
|
||||
/**
|
||||
* Crudely converts a model id to a pipeline model field
|
||||
* TODO: Make better
|
||||
*/
|
||||
export const modelIdToPipelineModelField = (
|
||||
modelId: string
|
||||
): PipelineModelField => {
|
||||
const [base_model, model_type, model_name] = modelId.split('/');
|
||||
|
||||
const field: PipelineModelField = {
|
||||
base_model: base_model as BaseModelType,
|
||||
model_name,
|
||||
};
|
||||
|
||||
return field;
|
||||
};
|
@ -0,0 +1,16 @@
|
||||
import { BaseModelType, VAEModelField } from 'services/api/types';
|
||||
|
||||
/**
|
||||
* Crudely converts a model id to a main model field
|
||||
* TODO: Make better
|
||||
*/
|
||||
export const modelIdToVAEModelField = (modelId: string): VAEModelField => {
|
||||
const [base_model, model_type, model_name] = modelId.split('/');
|
||||
|
||||
const field: VAEModelField = {
|
||||
base_model: base_model as BaseModelType,
|
||||
model_name,
|
||||
};
|
||||
|
||||
return field;
|
||||
};
|
@ -1,19 +1,19 @@
|
||||
import { Box, Flex } from '@chakra-ui/react';
|
||||
import ModelSelect from 'features/system/components/ModelSelect';
|
||||
import VAESelect from 'features/system/components/VAESelect';
|
||||
import { memo } from 'react';
|
||||
import ParamScheduler from './ParamScheduler';
|
||||
|
||||
const ParamSchedulerAndModel = () => {
|
||||
const ParamModelandVAE = () => {
|
||||
return (
|
||||
<Flex gap={3} w="full">
|
||||
<Box w="25rem">
|
||||
<ParamScheduler />
|
||||
</Box>
|
||||
<Box w="full">
|
||||
<ModelSelect />
|
||||
</Box>
|
||||
<Box w="full">
|
||||
<VAESelect />
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamSchedulerAndModel);
|
||||
export default memo(ParamModelandVAE);
|
@ -14,6 +14,7 @@ import {
|
||||
SeedParam,
|
||||
StepsParam,
|
||||
StrengthParam,
|
||||
VAEParam,
|
||||
WidthParam,
|
||||
} from './parameterZodSchemas';
|
||||
|
||||
@ -47,6 +48,7 @@ export interface GenerationState {
|
||||
horizontalSymmetrySteps: number;
|
||||
verticalSymmetrySteps: number;
|
||||
model: ModelParam;
|
||||
vae: VAEParam;
|
||||
shouldUseSeamless: boolean;
|
||||
seamlessXAxis: boolean;
|
||||
seamlessYAxis: boolean;
|
||||
@ -81,6 +83,7 @@ export const initialGenerationState: GenerationState = {
|
||||
horizontalSymmetrySteps: 0,
|
||||
verticalSymmetrySteps: 0,
|
||||
model: '',
|
||||
vae: '',
|
||||
shouldUseSeamless: false,
|
||||
seamlessXAxis: true,
|
||||
seamlessYAxis: true,
|
||||
@ -216,6 +219,9 @@ export const generationSlice = createSlice({
|
||||
modelSelected: (state, action: PayloadAction<string>) => {
|
||||
state.model = action.payload;
|
||||
},
|
||||
vaeSelected: (state, action: PayloadAction<string>) => {
|
||||
state.vae = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder.addCase(configChanged, (state, action) => {
|
||||
@ -260,6 +266,7 @@ export const {
|
||||
setVerticalSymmetrySteps,
|
||||
initialImageChanged,
|
||||
modelSelected,
|
||||
vaeSelected,
|
||||
setShouldUseNoiseSettings,
|
||||
setSeamless,
|
||||
setSeamlessXAxis,
|
||||
|
@ -135,6 +135,15 @@ export const zModel = z.string();
|
||||
* Type alias for model parameter, inferred from its zod schema
|
||||
*/
|
||||
export type ModelParam = z.infer<typeof zModel>;
|
||||
/**
|
||||
* Zod schema for VAE parameter
|
||||
* TODO: Make this a dynamically generated enum?
|
||||
*/
|
||||
export const zVAE = z.string();
|
||||
/**
|
||||
* Type alias for model parameter, inferred from its zod schema
|
||||
*/
|
||||
export type VAEParam = z.infer<typeof zVAE>;
|
||||
/**
|
||||
* Validates/type-guards a value as a model parameter
|
||||
*/
|
||||
|
@ -1,125 +0,0 @@
|
||||
import {
|
||||
Button,
|
||||
Flex,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
Text,
|
||||
useDisclosure,
|
||||
} from '@chakra-ui/react';
|
||||
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
|
||||
import { FaArrowLeft, FaPlus } from 'react-icons/fa';
|
||||
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
|
||||
import AddCheckpointModel from './AddCheckpointModel';
|
||||
import AddDiffusersModel from './AddDiffusersModel';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
|
||||
function AddModelBox({
|
||||
text,
|
||||
onClick,
|
||||
}: {
|
||||
text: string;
|
||||
onClick?: () => void;
|
||||
}) {
|
||||
return (
|
||||
<Flex
|
||||
position="relative"
|
||||
width="50%"
|
||||
height={40}
|
||||
justifyContent="center"
|
||||
alignItems="center"
|
||||
onClick={onClick}
|
||||
as={Button}
|
||||
>
|
||||
<Text fontWeight="bold">{text}</Text>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
export default function AddModel() {
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
|
||||
const addNewModelUIOption = useAppSelector(
|
||||
(state: RootState) => state.ui.addNewModelUIOption
|
||||
);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const addModelModalClose = () => {
|
||||
onClose();
|
||||
dispatch(setAddNewModelUIOption(null));
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<IAIButton
|
||||
aria-label={t('modelManager.addNewModel')}
|
||||
tooltip={t('modelManager.addNewModel')}
|
||||
onClick={onOpen}
|
||||
size="sm"
|
||||
>
|
||||
<Flex columnGap={2} alignItems="center">
|
||||
<FaPlus />
|
||||
{t('modelManager.addNew')}
|
||||
</Flex>
|
||||
</IAIButton>
|
||||
|
||||
<Modal
|
||||
isOpen={isOpen}
|
||||
onClose={addModelModalClose}
|
||||
size="3xl"
|
||||
closeOnOverlayClick={false}
|
||||
>
|
||||
<ModalOverlay />
|
||||
<ModalContent margin="auto">
|
||||
<ModalHeader>{t('modelManager.addNewModel')} </ModalHeader>
|
||||
{addNewModelUIOption !== null && (
|
||||
<IAIIconButton
|
||||
aria-label={t('common.back')}
|
||||
tooltip={t('common.back')}
|
||||
onClick={() => dispatch(setAddNewModelUIOption(null))}
|
||||
position="absolute"
|
||||
variant="ghost"
|
||||
zIndex={1}
|
||||
size="sm"
|
||||
insetInlineEnd={12}
|
||||
top={2}
|
||||
icon={<FaArrowLeft />}
|
||||
/>
|
||||
)}
|
||||
<ModalCloseButton />
|
||||
<ModalBody>
|
||||
{addNewModelUIOption == null && (
|
||||
<Flex columnGap={4}>
|
||||
<AddModelBox
|
||||
text={t('modelManager.addCheckpointModel')}
|
||||
onClick={() => dispatch(setAddNewModelUIOption('ckpt'))}
|
||||
/>
|
||||
<AddModelBox
|
||||
text={t('modelManager.addDiffuserModel')}
|
||||
onClick={() => dispatch(setAddNewModelUIOption('diffusers'))}
|
||||
/>
|
||||
</Flex>
|
||||
)}
|
||||
{addNewModelUIOption == 'ckpt' && <AddCheckpointModel />}
|
||||
{addNewModelUIOption == 'diffusers' && <AddDiffusersModel />}
|
||||
</ModalBody>
|
||||
<ModalFooter />
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
</>
|
||||
);
|
||||
}
|
@ -1,339 +0,0 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
import IAINumberInput from 'common/components/IAINumberInput';
|
||||
import { useEffect, useState } from 'react';
|
||||
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
|
||||
import {
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
HStack,
|
||||
Text,
|
||||
VStack,
|
||||
} from '@chakra-ui/react';
|
||||
|
||||
// import { addNewModel } from 'app/socketio/actions';
|
||||
import { Field, Formik } from 'formik';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { InvokeModelConfigProps } from 'app/types/invokeai';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { FieldInputProps, FormikProps } from 'formik';
|
||||
import { isEqual, pickBy } from 'lodash-es';
|
||||
import ModelConvert from './ModelConvert';
|
||||
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
|
||||
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
|
||||
import IAIForm from 'common/components/IAIForm';
|
||||
|
||||
const selector = createSelector(
|
||||
[systemSelector],
|
||||
(system) => {
|
||||
const { openModel, model_list } = system;
|
||||
return {
|
||||
model_list,
|
||||
openModel,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
const MIN_MODEL_SIZE = 64;
|
||||
const MAX_MODEL_SIZE = 2048;
|
||||
|
||||
export default function CheckpointModelEdit() {
|
||||
const { openModel, model_list } = useAppSelector(selector);
|
||||
const isProcessing = useAppSelector(
|
||||
(state: RootState) => state.system.isProcessing
|
||||
);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const [editModelFormValues, setEditModelFormValues] =
|
||||
useState<InvokeModelConfigProps>({
|
||||
name: '',
|
||||
description: '',
|
||||
config: 'configs/stable-diffusion/v1-inference.yaml',
|
||||
weights: '',
|
||||
vae: '',
|
||||
width: 512,
|
||||
height: 512,
|
||||
default: false,
|
||||
format: 'ckpt',
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (openModel) {
|
||||
const retrievedModel = pickBy(model_list, (_val, key) => {
|
||||
return isEqual(key, openModel);
|
||||
});
|
||||
setEditModelFormValues({
|
||||
name: openModel,
|
||||
description: retrievedModel[openModel]?.description,
|
||||
config: retrievedModel[openModel]?.config,
|
||||
weights: retrievedModel[openModel]?.weights,
|
||||
vae: retrievedModel[openModel]?.vae,
|
||||
width: retrievedModel[openModel]?.width,
|
||||
height: retrievedModel[openModel]?.height,
|
||||
default: retrievedModel[openModel]?.default,
|
||||
format: 'ckpt',
|
||||
});
|
||||
}
|
||||
}, [model_list, openModel]);
|
||||
|
||||
const editModelFormSubmitHandler = (values: InvokeModelConfigProps) => {
|
||||
dispatch(
|
||||
addNewModel({
|
||||
...values,
|
||||
width: Number(values.width),
|
||||
height: Number(values.height),
|
||||
})
|
||||
);
|
||||
};
|
||||
|
||||
return openModel ? (
|
||||
<Flex flexDirection="column" rowGap={4} width="100%">
|
||||
<Flex alignItems="center" gap={4} justifyContent="space-between">
|
||||
<Text fontSize="lg" fontWeight="bold">
|
||||
{openModel}
|
||||
</Text>
|
||||
<ModelConvert model={openModel} />
|
||||
</Flex>
|
||||
<Flex
|
||||
flexDirection="column"
|
||||
maxHeight={window.innerHeight - 270}
|
||||
overflowY="scroll"
|
||||
paddingInlineEnd={8}
|
||||
>
|
||||
<Formik
|
||||
enableReinitialize={true}
|
||||
initialValues={editModelFormValues}
|
||||
onSubmit={editModelFormSubmitHandler}
|
||||
>
|
||||
{({ handleSubmit, errors, touched }) => (
|
||||
<IAIForm onSubmit={handleSubmit}>
|
||||
<VStack rowGap={2} alignItems="start">
|
||||
{/* Description */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.description && touched.description}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="description" fontSize="sm">
|
||||
{t('modelManager.description')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="description"
|
||||
name="description"
|
||||
type="text"
|
||||
width="full"
|
||||
/>
|
||||
{!!errors.description && touched.description ? (
|
||||
<IAIFormErrorMessage>
|
||||
{errors.description}
|
||||
</IAIFormErrorMessage>
|
||||
) : (
|
||||
<IAIFormHelperText>
|
||||
{t('modelManager.descriptionValidationMsg')}
|
||||
</IAIFormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
{/* Config */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.config && touched.config}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="config" fontSize="sm">
|
||||
{t('modelManager.config')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="config"
|
||||
name="config"
|
||||
type="text"
|
||||
width="full"
|
||||
/>
|
||||
{!!errors.config && touched.config ? (
|
||||
<IAIFormErrorMessage>{errors.config}</IAIFormErrorMessage>
|
||||
) : (
|
||||
<IAIFormHelperText>
|
||||
{t('modelManager.configValidationMsg')}
|
||||
</IAIFormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
{/* Weights */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.weights && touched.weights}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="config" fontSize="sm">
|
||||
{t('modelManager.modelLocation')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="weights"
|
||||
name="weights"
|
||||
type="text"
|
||||
width="full"
|
||||
/>
|
||||
{!!errors.weights && touched.weights ? (
|
||||
<IAIFormErrorMessage>
|
||||
{errors.weights}
|
||||
</IAIFormErrorMessage>
|
||||
) : (
|
||||
<IAIFormHelperText>
|
||||
{t('modelManager.modelLocationValidationMsg')}
|
||||
</IAIFormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
{/* VAE */}
|
||||
<FormControl isInvalid={!!errors.vae && touched.vae}>
|
||||
<FormLabel htmlFor="vae" fontSize="sm">
|
||||
{t('modelManager.vaeLocation')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="vae"
|
||||
name="vae"
|
||||
type="text"
|
||||
width="full"
|
||||
/>
|
||||
{!!errors.vae && touched.vae ? (
|
||||
<IAIFormErrorMessage>{errors.vae}</IAIFormErrorMessage>
|
||||
) : (
|
||||
<IAIFormHelperText>
|
||||
{t('modelManager.vaeLocationValidationMsg')}
|
||||
</IAIFormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
<HStack width="100%">
|
||||
{/* Width */}
|
||||
<FormControl isInvalid={!!errors.width && touched.width}>
|
||||
<FormLabel htmlFor="width" fontSize="sm">
|
||||
{t('modelManager.width')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field id="width" name="width">
|
||||
{({
|
||||
field,
|
||||
form,
|
||||
}: {
|
||||
field: FieldInputProps<number>;
|
||||
form: FormikProps<InvokeModelConfigProps>;
|
||||
}) => (
|
||||
<IAINumberInput
|
||||
id="width"
|
||||
name="width"
|
||||
min={MIN_MODEL_SIZE}
|
||||
max={MAX_MODEL_SIZE}
|
||||
step={64}
|
||||
value={form.values.width}
|
||||
onChange={(value) =>
|
||||
form.setFieldValue(field.name, Number(value))
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</Field>
|
||||
|
||||
{!!errors.width && touched.width ? (
|
||||
<IAIFormErrorMessage>
|
||||
{errors.width}
|
||||
</IAIFormErrorMessage>
|
||||
) : (
|
||||
<IAIFormHelperText>
|
||||
{t('modelManager.widthValidationMsg')}
|
||||
</IAIFormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
{/* Height */}
|
||||
<FormControl isInvalid={!!errors.height && touched.height}>
|
||||
<FormLabel htmlFor="height" fontSize="sm">
|
||||
{t('modelManager.height')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field id="height" name="height">
|
||||
{({
|
||||
field,
|
||||
form,
|
||||
}: {
|
||||
field: FieldInputProps<number>;
|
||||
form: FormikProps<InvokeModelConfigProps>;
|
||||
}) => (
|
||||
<IAINumberInput
|
||||
id="height"
|
||||
name="height"
|
||||
min={MIN_MODEL_SIZE}
|
||||
max={MAX_MODEL_SIZE}
|
||||
step={64}
|
||||
value={form.values.height}
|
||||
onChange={(value) =>
|
||||
form.setFieldValue(field.name, Number(value))
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</Field>
|
||||
|
||||
{!!errors.height && touched.height ? (
|
||||
<IAIFormErrorMessage>
|
||||
{errors.height}
|
||||
</IAIFormErrorMessage>
|
||||
) : (
|
||||
<IAIFormHelperText>
|
||||
{t('modelManager.heightValidationMsg')}
|
||||
</IAIFormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</HStack>
|
||||
|
||||
<IAIButton
|
||||
type="submit"
|
||||
className="modal-close-btn"
|
||||
isLoading={isProcessing}
|
||||
>
|
||||
{t('modelManager.updateModel')}
|
||||
</IAIButton>
|
||||
</VStack>
|
||||
</IAIForm>
|
||||
)}
|
||||
</Formik>
|
||||
</Flex>
|
||||
</Flex>
|
||||
) : (
|
||||
<Flex
|
||||
sx={{
|
||||
width: '100%',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
borderRadius: 'base',
|
||||
bg: 'base.900',
|
||||
}}
|
||||
>
|
||||
<Text fontWeight={500}>Pick A Model To Edit</Text>
|
||||
</Flex>
|
||||
);
|
||||
}
|
@ -1,281 +0,0 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
import { useEffect, useState } from 'react';
|
||||
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
|
||||
import { Flex, FormControl, FormLabel, Text, VStack } from '@chakra-ui/react';
|
||||
|
||||
// import { addNewModel } from 'app/socketio/actions';
|
||||
import { Field, Formik } from 'formik';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { InvokeDiffusersModelConfigProps } from 'app/types/invokeai';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { isEqual, pickBy } from 'lodash-es';
|
||||
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
|
||||
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
|
||||
import IAIForm from 'common/components/IAIForm';
|
||||
|
||||
const selector = createSelector(
|
||||
[systemSelector],
|
||||
(system) => {
|
||||
const { openModel, model_list } = system;
|
||||
return {
|
||||
model_list,
|
||||
openModel,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
export default function DiffusersModelEdit() {
|
||||
const { openModel, model_list } = useAppSelector(selector);
|
||||
const isProcessing = useAppSelector(
|
||||
(state: RootState) => state.system.isProcessing
|
||||
);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const [editModelFormValues, setEditModelFormValues] =
|
||||
useState<InvokeDiffusersModelConfigProps>({
|
||||
name: '',
|
||||
description: '',
|
||||
repo_id: '',
|
||||
path: '',
|
||||
vae: { repo_id: '', path: '' },
|
||||
default: false,
|
||||
format: 'diffusers',
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (openModel) {
|
||||
const retrievedModel = pickBy(model_list, (_val, key) => {
|
||||
return isEqual(key, openModel);
|
||||
});
|
||||
|
||||
setEditModelFormValues({
|
||||
name: openModel,
|
||||
description: retrievedModel[openModel]?.description,
|
||||
path:
|
||||
retrievedModel[openModel]?.path &&
|
||||
retrievedModel[openModel]?.path !== 'None'
|
||||
? retrievedModel[openModel]?.path
|
||||
: '',
|
||||
repo_id:
|
||||
retrievedModel[openModel]?.repo_id &&
|
||||
retrievedModel[openModel]?.repo_id !== 'None'
|
||||
? retrievedModel[openModel]?.repo_id
|
||||
: '',
|
||||
vae: {
|
||||
repo_id: retrievedModel[openModel]?.vae?.repo_id
|
||||
? retrievedModel[openModel]?.vae?.repo_id
|
||||
: '',
|
||||
path: retrievedModel[openModel]?.vae?.path
|
||||
? retrievedModel[openModel]?.vae?.path
|
||||
: '',
|
||||
},
|
||||
default: retrievedModel[openModel]?.default,
|
||||
format: 'diffusers',
|
||||
});
|
||||
}
|
||||
}, [model_list, openModel]);
|
||||
|
||||
const editModelFormSubmitHandler = (
|
||||
values: InvokeDiffusersModelConfigProps
|
||||
) => {
|
||||
const diffusersModelToEdit = values;
|
||||
|
||||
if (values.path === '') delete diffusersModelToEdit.path;
|
||||
if (values.repo_id === '') delete diffusersModelToEdit.repo_id;
|
||||
if (values.vae.path === '') delete diffusersModelToEdit.vae.path;
|
||||
if (values.vae.repo_id === '') delete diffusersModelToEdit.vae.repo_id;
|
||||
|
||||
dispatch(addNewModel(values));
|
||||
};
|
||||
|
||||
return openModel ? (
|
||||
<Flex flexDirection="column" rowGap={4} width="100%">
|
||||
<Flex alignItems="center">
|
||||
<Text fontSize="lg" fontWeight="bold">
|
||||
{openModel}
|
||||
</Text>
|
||||
</Flex>
|
||||
<Flex flexDirection="column" overflowY="scroll" paddingInlineEnd={8}>
|
||||
<Formik
|
||||
enableReinitialize={true}
|
||||
initialValues={editModelFormValues}
|
||||
onSubmit={editModelFormSubmitHandler}
|
||||
>
|
||||
{({ handleSubmit, errors, touched }) => (
|
||||
<IAIForm onSubmit={handleSubmit}>
|
||||
<VStack rowGap={2} alignItems="start">
|
||||
{/* Description */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.description && touched.description}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="description" fontSize="sm">
|
||||
{t('modelManager.description')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="description"
|
||||
name="description"
|
||||
type="text"
|
||||
width="full"
|
||||
/>
|
||||
{!!errors.description && touched.description ? (
|
||||
<IAIFormErrorMessage>
|
||||
{errors.description}
|
||||
</IAIFormErrorMessage>
|
||||
) : (
|
||||
<IAIFormHelperText>
|
||||
{t('modelManager.descriptionValidationMsg')}
|
||||
</IAIFormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
{/* Path */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.path && touched.path}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="path" fontSize="sm">
|
||||
{t('modelManager.modelLocation')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="path"
|
||||
name="path"
|
||||
type="text"
|
||||
width="full"
|
||||
/>
|
||||
{!!errors.path && touched.path ? (
|
||||
<IAIFormErrorMessage>{errors.path}</IAIFormErrorMessage>
|
||||
) : (
|
||||
<IAIFormHelperText>
|
||||
{t('modelManager.modelLocationValidationMsg')}
|
||||
</IAIFormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
{/* Repo ID */}
|
||||
<FormControl isInvalid={!!errors.repo_id && touched.repo_id}>
|
||||
<FormLabel htmlFor="repo_id" fontSize="sm">
|
||||
{t('modelManager.repo_id')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="repo_id"
|
||||
name="repo_id"
|
||||
type="text"
|
||||
width="full"
|
||||
/>
|
||||
{!!errors.repo_id && touched.repo_id ? (
|
||||
<IAIFormErrorMessage>
|
||||
{errors.repo_id}
|
||||
</IAIFormErrorMessage>
|
||||
) : (
|
||||
<IAIFormHelperText>
|
||||
{t('modelManager.repoIDValidationMsg')}
|
||||
</IAIFormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
{/* VAE Path */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.vae?.path && touched.vae?.path}
|
||||
>
|
||||
<FormLabel htmlFor="vae.path" fontSize="sm">
|
||||
{t('modelManager.vaeLocation')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="vae.path"
|
||||
name="vae.path"
|
||||
type="text"
|
||||
width="full"
|
||||
/>
|
||||
{!!errors.vae?.path && touched.vae?.path ? (
|
||||
<IAIFormErrorMessage>
|
||||
{errors.vae?.path}
|
||||
</IAIFormErrorMessage>
|
||||
) : (
|
||||
<IAIFormHelperText>
|
||||
{t('modelManager.vaeLocationValidationMsg')}
|
||||
</IAIFormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
{/* VAE Repo ID */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.vae?.repo_id && touched.vae?.repo_id}
|
||||
>
|
||||
<FormLabel htmlFor="vae.repo_id" fontSize="sm">
|
||||
{t('modelManager.vaeRepoID')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="vae.repo_id"
|
||||
name="vae.repo_id"
|
||||
type="text"
|
||||
width="full"
|
||||
/>
|
||||
{!!errors.vae?.repo_id && touched.vae?.repo_id ? (
|
||||
<IAIFormErrorMessage>
|
||||
{errors.vae?.repo_id}
|
||||
</IAIFormErrorMessage>
|
||||
) : (
|
||||
<IAIFormHelperText>
|
||||
{t('modelManager.vaeRepoIDValidationMsg')}
|
||||
</IAIFormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
<IAIButton
|
||||
type="submit"
|
||||
className="modal-close-btn"
|
||||
isLoading={isProcessing}
|
||||
>
|
||||
{t('modelManager.updateModel')}
|
||||
</IAIButton>
|
||||
</VStack>
|
||||
</IAIForm>
|
||||
)}
|
||||
</Formik>
|
||||
</Flex>
|
||||
</Flex>
|
||||
) : (
|
||||
<Flex
|
||||
sx={{
|
||||
width: '100%',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
borderRadius: 'base',
|
||||
bg: 'base.900',
|
||||
}}
|
||||
>
|
||||
<Text fontWeight={'500'}>Pick A Model To Edit</Text>
|
||||
</Flex>
|
||||
);
|
||||
}
|
@ -1,313 +0,0 @@
|
||||
import {
|
||||
Flex,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
Radio,
|
||||
RadioGroup,
|
||||
Text,
|
||||
Tooltip,
|
||||
useDisclosure,
|
||||
} from '@chakra-ui/react';
|
||||
// import { mergeDiffusersModels } from 'app/socketio/actions';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
import IAISelect from 'common/components/IAISelect';
|
||||
import { diffusersModelsSelector } from 'features/system/store/systemSelectors';
|
||||
import { useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import * as InvokeAI from 'app/types/invokeai';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
||||
|
||||
export default function MergeModels() {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
|
||||
const diffusersModels = useAppSelector(diffusersModelsSelector);
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const [modelOne, setModelOne] = useState<string>(
|
||||
Object.keys(diffusersModels)[0]
|
||||
);
|
||||
const [modelTwo, setModelTwo] = useState<string>(
|
||||
Object.keys(diffusersModels)[1]
|
||||
);
|
||||
const [modelThree, setModelThree] = useState<string>('none');
|
||||
|
||||
const [mergedModelName, setMergedModelName] = useState<string>('');
|
||||
const [modelMergeAlpha, setModelMergeAlpha] = useState<number>(0.5);
|
||||
|
||||
const [modelMergeInterp, setModelMergeInterp] = useState<
|
||||
'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference'
|
||||
>('weighted_sum');
|
||||
|
||||
const [modelMergeSaveLocType, setModelMergeSaveLocType] = useState<
|
||||
'root' | 'custom'
|
||||
>('root');
|
||||
|
||||
const [modelMergeCustomSaveLoc, setModelMergeCustomSaveLoc] =
|
||||
useState<string>('');
|
||||
|
||||
const [modelMergeForce, setModelMergeForce] = useState<boolean>(false);
|
||||
|
||||
const modelOneList = Object.keys(diffusersModels).filter(
|
||||
(model) => model !== modelTwo && model !== modelThree
|
||||
);
|
||||
|
||||
const modelTwoList = Object.keys(diffusersModels).filter(
|
||||
(model) => model !== modelOne && model !== modelThree
|
||||
);
|
||||
|
||||
const modelThreeList = [
|
||||
{ key: t('modelManager.none'), value: 'none' },
|
||||
...Object.keys(diffusersModels)
|
||||
.filter((model) => model !== modelOne && model !== modelTwo)
|
||||
.map((model) => ({ key: model, value: model })),
|
||||
];
|
||||
|
||||
const isProcessing = useAppSelector(
|
||||
(state: RootState) => state.system.isProcessing
|
||||
);
|
||||
|
||||
const mergeModelsHandler = () => {
|
||||
let modelsToMerge: string[] = [modelOne, modelTwo, modelThree];
|
||||
modelsToMerge = modelsToMerge.filter((model) => model !== 'none');
|
||||
|
||||
const mergeModelsInfo: InvokeAI.InvokeModelMergingProps = {
|
||||
models_to_merge: modelsToMerge,
|
||||
merged_model_name:
|
||||
mergedModelName !== '' ? mergedModelName : modelsToMerge.join('-'),
|
||||
alpha: modelMergeAlpha,
|
||||
interp: modelMergeInterp,
|
||||
model_merge_save_path:
|
||||
modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc,
|
||||
force: modelMergeForce,
|
||||
};
|
||||
|
||||
dispatch(mergeDiffusersModels(mergeModelsInfo));
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<IAIButton onClick={onOpen} size="sm">
|
||||
<Flex columnGap={2} alignItems="center">
|
||||
{t('modelManager.mergeModels')}
|
||||
</Flex>
|
||||
</IAIButton>
|
||||
|
||||
<Modal
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
size="4xl"
|
||||
closeOnOverlayClick={false}
|
||||
>
|
||||
<ModalOverlay />
|
||||
<ModalContent fontFamily="Inter" margin="auto" paddingInlineEnd={4}>
|
||||
<ModalHeader>{t('modelManager.mergeModels')}</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody>
|
||||
<Flex flexDirection="column" rowGap={4}>
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
marginBottom: 4,
|
||||
padding: 4,
|
||||
borderRadius: 'base',
|
||||
rowGap: 1,
|
||||
bg: 'base.900',
|
||||
}}
|
||||
>
|
||||
<Text>{t('modelManager.modelMergeHeaderHelp1')}</Text>
|
||||
<Text fontSize="sm" variant="subtext">
|
||||
{t('modelManager.modelMergeHeaderHelp2')}
|
||||
</Text>
|
||||
</Flex>
|
||||
<Flex columnGap={4}>
|
||||
<IAISelect
|
||||
label={t('modelManager.modelOne')}
|
||||
validValues={modelOneList}
|
||||
onChange={(e) => setModelOne(e.target.value)}
|
||||
/>
|
||||
<IAISelect
|
||||
label={t('modelManager.modelTwo')}
|
||||
validValues={modelTwoList}
|
||||
onChange={(e) => setModelTwo(e.target.value)}
|
||||
/>
|
||||
<IAISelect
|
||||
label={t('modelManager.modelThree')}
|
||||
validValues={modelThreeList}
|
||||
onChange={(e) => {
|
||||
if (e.target.value !== 'none') {
|
||||
setModelThree(e.target.value);
|
||||
setModelMergeInterp('add_difference');
|
||||
} else {
|
||||
setModelThree('none');
|
||||
setModelMergeInterp('weighted_sum');
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</Flex>
|
||||
|
||||
<IAIInput
|
||||
label={t('modelManager.mergedModelName')}
|
||||
value={mergedModelName}
|
||||
onChange={(e) => setMergedModelName(e.target.value)}
|
||||
/>
|
||||
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
padding: 4,
|
||||
borderRadius: 'base',
|
||||
gap: 4,
|
||||
bg: 'base.900',
|
||||
}}
|
||||
>
|
||||
<IAISlider
|
||||
label={t('modelManager.alpha')}
|
||||
min={0.01}
|
||||
max={0.99}
|
||||
step={0.01}
|
||||
value={modelMergeAlpha}
|
||||
onChange={(v) => setModelMergeAlpha(v)}
|
||||
withInput
|
||||
withReset
|
||||
handleReset={() => setModelMergeAlpha(0.5)}
|
||||
withSliderMarks
|
||||
/>
|
||||
<Text variant="subtext" fontSize="sm">
|
||||
{t('modelManager.modelMergeAlphaHelp')}
|
||||
</Text>
|
||||
</Flex>
|
||||
|
||||
<Flex
|
||||
sx={{
|
||||
padding: 4,
|
||||
borderRadius: 'base',
|
||||
gap: 4,
|
||||
bg: 'base.900',
|
||||
}}
|
||||
>
|
||||
<Text fontWeight={500} fontSize="sm" variant="subtext">
|
||||
{t('modelManager.interpolationType')}
|
||||
</Text>
|
||||
<RadioGroup
|
||||
value={modelMergeInterp}
|
||||
onChange={(
|
||||
v:
|
||||
| 'weighted_sum'
|
||||
| 'sigmoid'
|
||||
| 'inv_sigmoid'
|
||||
| 'add_difference'
|
||||
) => setModelMergeInterp(v)}
|
||||
>
|
||||
<Flex columnGap={4}>
|
||||
{modelThree === 'none' ? (
|
||||
<>
|
||||
<Radio value="weighted_sum">
|
||||
<Text fontSize="sm">
|
||||
{t('modelManager.weightedSum')}
|
||||
</Text>
|
||||
</Radio>
|
||||
<Radio value="sigmoid">
|
||||
<Text fontSize="sm">{t('modelManager.sigmoid')}</Text>
|
||||
</Radio>
|
||||
<Radio value="inv_sigmoid">
|
||||
<Text fontSize="sm">
|
||||
{t('modelManager.inverseSigmoid')}
|
||||
</Text>
|
||||
</Radio>
|
||||
</>
|
||||
) : (
|
||||
<Radio value="add_difference">
|
||||
<Tooltip
|
||||
label={t(
|
||||
'modelManager.modelMergeInterpAddDifferenceHelp'
|
||||
)}
|
||||
>
|
||||
<Text fontSize="sm">
|
||||
{t('modelManager.addDifference')}
|
||||
</Text>
|
||||
</Tooltip>
|
||||
</Radio>
|
||||
)}
|
||||
</Flex>
|
||||
</RadioGroup>
|
||||
</Flex>
|
||||
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
padding: 4,
|
||||
borderRadius: 'base',
|
||||
gap: 4,
|
||||
bg: 'base.900',
|
||||
}}
|
||||
>
|
||||
<Flex columnGap={4}>
|
||||
<Text fontWeight="500" fontSize="sm" variant="subtext">
|
||||
{t('modelManager.mergedModelSaveLocation')}
|
||||
</Text>
|
||||
<RadioGroup
|
||||
value={modelMergeSaveLocType}
|
||||
onChange={(v: 'root' | 'custom') =>
|
||||
setModelMergeSaveLocType(v)
|
||||
}
|
||||
>
|
||||
<Flex columnGap={4}>
|
||||
<Radio value="root">
|
||||
<Text fontSize="sm">
|
||||
{t('modelManager.invokeAIFolder')}
|
||||
</Text>
|
||||
</Radio>
|
||||
|
||||
<Radio value="custom">
|
||||
<Text fontSize="sm">{t('modelManager.custom')}</Text>
|
||||
</Radio>
|
||||
</Flex>
|
||||
</RadioGroup>
|
||||
</Flex>
|
||||
|
||||
{modelMergeSaveLocType === 'custom' && (
|
||||
<IAIInput
|
||||
label={t('modelManager.mergedModelCustomSaveLocation')}
|
||||
value={modelMergeCustomSaveLoc}
|
||||
onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
|
||||
<IAISimpleCheckbox
|
||||
label={t('modelManager.ignoreMismatch')}
|
||||
isChecked={modelMergeForce}
|
||||
onChange={(e) => setModelMergeForce(e.target.checked)}
|
||||
fontWeight="500"
|
||||
/>
|
||||
|
||||
<IAIButton
|
||||
onClick={mergeModelsHandler}
|
||||
isLoading={isProcessing}
|
||||
isDisabled={
|
||||
modelMergeSaveLocType === 'custom' &&
|
||||
modelMergeCustomSaveLoc === ''
|
||||
}
|
||||
>
|
||||
{t('modelManager.merge')}
|
||||
</IAIButton>
|
||||
</Flex>
|
||||
</ModalBody>
|
||||
<ModalFooter />
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
</>
|
||||
);
|
||||
}
|
@ -1,76 +0,0 @@
|
||||
import {
|
||||
Flex,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
useDisclosure,
|
||||
} from '@chakra-ui/react';
|
||||
import { cloneElement } from 'react';
|
||||
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { ReactElement } from 'react';
|
||||
|
||||
import CheckpointModelEdit from './CheckpointModelEdit';
|
||||
import DiffusersModelEdit from './DiffusersModelEdit';
|
||||
import ModelList from './ModelList';
|
||||
|
||||
type ModelManagerModalProps = {
|
||||
children: ReactElement;
|
||||
};
|
||||
|
||||
export default function ModelManagerModal({
|
||||
children,
|
||||
}: ModelManagerModalProps) {
|
||||
const {
|
||||
isOpen: isModelManagerModalOpen,
|
||||
onOpen: onModelManagerModalOpen,
|
||||
onClose: onModelManagerModalClose,
|
||||
} = useDisclosure();
|
||||
|
||||
const model_list = useAppSelector(
|
||||
(state: RootState) => state.system.model_list
|
||||
);
|
||||
|
||||
const openModel = useAppSelector(
|
||||
(state: RootState) => state.system.openModel
|
||||
);
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<>
|
||||
{cloneElement(children, {
|
||||
onClick: onModelManagerModalOpen,
|
||||
})}
|
||||
<Modal
|
||||
isOpen={isModelManagerModalOpen}
|
||||
onClose={onModelManagerModalClose}
|
||||
size="full"
|
||||
>
|
||||
<ModalOverlay />
|
||||
<ModalContent>
|
||||
<ModalCloseButton />
|
||||
<ModalHeader>{t('modelManager.modelManager')}</ModalHeader>
|
||||
<ModalBody>
|
||||
<Flex width="100%" columnGap={8}>
|
||||
<ModelList />
|
||||
{openModel && model_list[openModel]['format'] === 'diffusers' ? (
|
||||
<DiffusersModelEdit />
|
||||
) : (
|
||||
<CheckpointModelEdit />
|
||||
)}
|
||||
</Flex>
|
||||
</ModalBody>
|
||||
<ModalFooter />
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
</>
|
||||
);
|
||||
}
|
@ -5,9 +5,9 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { modelSelected } from 'features/parameters/store/generationSlice';
|
||||
|
||||
import { forEach, isString } from 'lodash-es';
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { forEach, isString } from 'lodash-es';
|
||||
import { useListModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
export const MODEL_TYPE_MAP = {
|
||||
@ -23,18 +23,18 @@ const ModelSelect = () => {
|
||||
(state: RootState) => state.generation.model
|
||||
);
|
||||
|
||||
const { data: pipelineModels, isLoading } = useListModelsQuery({
|
||||
const { data: mainModels, isLoading } = useListModelsQuery({
|
||||
model_type: 'main',
|
||||
});
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!pipelineModels) {
|
||||
if (!mainModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(pipelineModels.entities, (model, id) => {
|
||||
forEach(mainModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
@ -47,11 +47,11 @@ const ModelSelect = () => {
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [pipelineModels]);
|
||||
}, [mainModels]);
|
||||
|
||||
const selectedModel = useMemo(
|
||||
() => pipelineModels?.entities[selectedModelId],
|
||||
[pipelineModels?.entities, selectedModelId]
|
||||
() => mainModels?.entities[selectedModelId],
|
||||
[mainModels?.entities, selectedModelId]
|
||||
);
|
||||
|
||||
const handleChangeModel = useCallback(
|
||||
@ -65,20 +65,18 @@ const ModelSelect = () => {
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
// If the selected model is not in the list of models, select the first one
|
||||
// Handles first-run setting of models, and the user deleting the previously-selected model
|
||||
if (selectedModelId && pipelineModels?.ids.includes(selectedModelId)) {
|
||||
if (selectedModelId && mainModels?.ids.includes(selectedModelId)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const firstModel = pipelineModels?.ids[0];
|
||||
const firstModel = mainModels?.ids[0];
|
||||
|
||||
if (!isString(firstModel)) {
|
||||
return;
|
||||
}
|
||||
|
||||
handleChangeModel(firstModel);
|
||||
}, [handleChangeModel, pipelineModels?.ids, selectedModelId]);
|
||||
}, [handleChangeModel, mainModels?.ids, selectedModelId]);
|
||||
|
||||
return isLoading ? (
|
||||
<IAIMantineSelect
|
||||
|
@ -5,21 +5,18 @@ import StatusIndicator from './StatusIndicator';
|
||||
import { Link } from '@chakra-ui/react';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaBug, FaCube, FaDiscord, FaGithub, FaKeyboard } from 'react-icons/fa';
|
||||
import { FaBug, FaDiscord, FaGithub, FaKeyboard } from 'react-icons/fa';
|
||||
import { MdSettings } from 'react-icons/md';
|
||||
import { useFeatureStatus } from '../hooks/useFeatureStatus';
|
||||
import ColorModeButton from './ColorModeButton';
|
||||
import HotkeysModal from './HotkeysModal/HotkeysModal';
|
||||
import InvokeAILogoComponent from './InvokeAILogoComponent';
|
||||
import LanguagePicker from './LanguagePicker';
|
||||
import ModelManagerModal from './ModelManager/ModelManagerModal';
|
||||
import SettingsModal from './SettingsModal/SettingsModal';
|
||||
import { useFeatureStatus } from '../hooks/useFeatureStatus';
|
||||
import ColorModeButton from './ColorModeButton';
|
||||
|
||||
const SiteHeader = () => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const isModelManagerEnabled =
|
||||
useFeatureStatus('modelManager').isFeatureEnabled;
|
||||
const isLocalizationEnabled =
|
||||
useFeatureStatus('localization').isFeatureEnabled;
|
||||
const isBugLinkEnabled = useFeatureStatus('bugLink').isFeatureEnabled;
|
||||
@ -37,20 +34,6 @@ const SiteHeader = () => {
|
||||
<Spacer />
|
||||
<StatusIndicator />
|
||||
|
||||
{isModelManagerEnabled && (
|
||||
<ModelManagerModal>
|
||||
<IAIIconButton
|
||||
aria-label={t('modelManager.modelManager')}
|
||||
tooltip={t('modelManager.modelManager')}
|
||||
size="sm"
|
||||
variant="link"
|
||||
data-variant="link"
|
||||
fontSize={20}
|
||||
icon={<FaCube />}
|
||||
/>
|
||||
</ModelManagerModal>
|
||||
)}
|
||||
|
||||
<HotkeysModal>
|
||||
<IAIIconButton
|
||||
aria-label={t('common.hotkeysLabel')}
|
||||
|
@ -1,10 +1,9 @@
|
||||
import { Flex, Link } from '@chakra-ui/react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaBug, FaCube, FaDiscord, FaGithub, FaKeyboard } from 'react-icons/fa';
|
||||
import { FaBug, FaDiscord, FaGithub, FaKeyboard } from 'react-icons/fa';
|
||||
import { MdSettings } from 'react-icons/md';
|
||||
import HotkeysModal from './HotkeysModal/HotkeysModal';
|
||||
import LanguagePicker from './LanguagePicker';
|
||||
import ModelManagerModal from './ModelManager/ModelManagerModal';
|
||||
import SettingsModal from './SettingsModal/SettingsModal';
|
||||
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
@ -13,8 +12,6 @@ import { useFeatureStatus } from '../hooks/useFeatureStatus';
|
||||
const SiteHeaderMenu = () => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const isModelManagerEnabled =
|
||||
useFeatureStatus('modelManager').isFeatureEnabled;
|
||||
const isLocalizationEnabled =
|
||||
useFeatureStatus('localization').isFeatureEnabled;
|
||||
const isBugLinkEnabled = useFeatureStatus('bugLink').isFeatureEnabled;
|
||||
@ -27,20 +24,6 @@ const SiteHeaderMenu = () => {
|
||||
flexDirection={{ base: 'column', xl: 'row' }}
|
||||
gap={{ base: 4, xl: 1 }}
|
||||
>
|
||||
{isModelManagerEnabled && (
|
||||
<ModelManagerModal>
|
||||
<IAIIconButton
|
||||
aria-label={t('modelManager.modelManager')}
|
||||
tooltip={t('modelManager.modelManager')}
|
||||
size="sm"
|
||||
variant="link"
|
||||
data-variant="link"
|
||||
fontSize={20}
|
||||
icon={<FaCube />}
|
||||
/>
|
||||
</ModelManagerModal>
|
||||
)}
|
||||
|
||||
<HotkeysModal>
|
||||
<IAIIconButton
|
||||
aria-label={t('common.hotkeysLabel')}
|
||||
|
@ -0,0 +1,89 @@
|
||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { useListModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { RootState } from 'app/store/store';
|
||||
import { vaeSelected } from 'features/parameters/store/generationSlice';
|
||||
import { MODEL_TYPE_MAP } from './ModelSelect';
|
||||
|
||||
const VAESelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { data: vaeModels } = useListModelsQuery({
|
||||
model_type: 'vae',
|
||||
});
|
||||
|
||||
const selectedModelId = useAppSelector(
|
||||
(state: RootState) => state.generation.vae
|
||||
);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!vaeModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [
|
||||
{
|
||||
value: 'auto',
|
||||
label: 'Automatic',
|
||||
group: 'Default',
|
||||
},
|
||||
];
|
||||
|
||||
forEach(vaeModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.name,
|
||||
group: MODEL_TYPE_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [vaeModels]);
|
||||
|
||||
const selectedModel = useMemo(
|
||||
() => vaeModels?.entities[selectedModelId],
|
||||
[vaeModels?.entities, selectedModelId]
|
||||
);
|
||||
|
||||
const handleChangeModel = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
dispatch(vaeSelected(v));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedModelId && vaeModels?.ids.includes(selectedModelId)) {
|
||||
return;
|
||||
}
|
||||
handleChangeModel('auto');
|
||||
}, [handleChangeModel, vaeModels?.ids, selectedModelId]);
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
tooltip={selectedModel?.description}
|
||||
label={t('modelManager.vae')}
|
||||
value={selectedModelId}
|
||||
placeholder="Pick one"
|
||||
data={data}
|
||||
onChange={handleChangeModel}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(VAESelect);
|
@ -1,13 +1,14 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import { setShouldShowGallery } from 'features/ui/store/uiSlice';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { MdPhotoLibrary } from 'react-icons/md';
|
||||
import { activeTabNameSelector, uiSelector } from '../store/uiSelectors';
|
||||
import { memo } from 'react';
|
||||
import { NO_GALLERY_TABS } from './InvokeTabs';
|
||||
|
||||
const floatingGalleryButtonSelector = createSelector(
|
||||
[activeTabNameSelector, uiSelector],
|
||||
@ -16,7 +17,9 @@ const floatingGalleryButtonSelector = createSelector(
|
||||
|
||||
return {
|
||||
shouldPinGallery,
|
||||
shouldShowGalleryButton: !shouldShowGallery,
|
||||
shouldShowGalleryButton: NO_GALLERY_TABS.includes(activeTabName)
|
||||
? false
|
||||
: !shouldShowGallery,
|
||||
};
|
||||
},
|
||||
{ memoizeOptions: { resultEqualityCheck: isEqual } }
|
||||
|
@ -9,35 +9,35 @@ import {
|
||||
Tooltip,
|
||||
VisuallyHidden,
|
||||
} from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import AuxiliaryProgressIndicator from 'app/components/AuxiliaryProgressIndicator';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent';
|
||||
import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
|
||||
import { configSelector } from 'features/system/store/configSelectors';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import { setActiveTab, togglePanels } from 'features/ui/store/uiSlice';
|
||||
import { memo, MouseEvent, ReactNode, useCallback, useMemo } from 'react';
|
||||
import { ResourceKey } from 'i18next';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { MouseEvent, ReactNode, memo, useCallback, useMemo } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaCube, FaFont, FaImage } from 'react-icons/fa';
|
||||
import { MdDeviceHub, MdGridOn } from 'react-icons/md';
|
||||
import { Panel, PanelGroup } from 'react-resizable-panels';
|
||||
import { useMinimumPanelSize } from '../hooks/useMinimumPanelSize';
|
||||
import {
|
||||
activeTabIndexSelector,
|
||||
activeTabNameSelector,
|
||||
} from '../store/uiSelectors';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ResourceKey } from 'i18next';
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { configSelector } from 'features/system/store/configSelectors';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { Panel, PanelGroup } from 'react-resizable-panels';
|
||||
import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent';
|
||||
import ImageTab from './tabs/ImageToImage/ImageToImageTab';
|
||||
import ModelManagerTab from './tabs/ModelManager/ModelManagerTab';
|
||||
import NodesTab from './tabs/Nodes/NodesTab';
|
||||
import ResizeHandle from './tabs/ResizeHandle';
|
||||
import TextToImageTab from './tabs/TextToImage/TextToImageTab';
|
||||
import UnifiedCanvasTab from './tabs/UnifiedCanvas/UnifiedCanvasTab';
|
||||
import NodesTab from './tabs/Nodes/NodesTab';
|
||||
import { FaFont, FaImage, FaLayerGroup } from 'react-icons/fa';
|
||||
import ResizeHandle from './tabs/ResizeHandle';
|
||||
import ImageTab from './tabs/ImageToImage/ImageToImageTab';
|
||||
import AuxiliaryProgressIndicator from 'app/components/AuxiliaryProgressIndicator';
|
||||
import { useMinimumPanelSize } from '../hooks/useMinimumPanelSize';
|
||||
import BatchTab from './tabs/Batch/BatchTab';
|
||||
|
||||
export interface InvokeTabInfo {
|
||||
id: InvokeTabName;
|
||||
@ -71,6 +71,11 @@ const tabs: InvokeTabInfo[] = [
|
||||
// icon: <Icon as={FaLayerGroup} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
|
||||
// content: <BatchTab />,
|
||||
// },
|
||||
{
|
||||
id: 'modelManager',
|
||||
icon: <Icon as={FaCube} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
|
||||
content: <ModelManagerTab />,
|
||||
},
|
||||
];
|
||||
|
||||
const enabledTabsSelector = createSelector(
|
||||
@ -87,6 +92,7 @@ const enabledTabsSelector = createSelector(
|
||||
|
||||
const MIN_GALLERY_WIDTH = 300;
|
||||
const DEFAULT_GALLERY_PCT = 20;
|
||||
export const NO_GALLERY_TABS: InvokeTabName[] = ['modelManager'];
|
||||
|
||||
const InvokeTabs = () => {
|
||||
const activeTab = useAppSelector(activeTabIndexSelector);
|
||||
@ -198,26 +204,28 @@ const InvokeTabs = () => {
|
||||
{tabPanels}
|
||||
</TabPanels>
|
||||
</Panel>
|
||||
{shouldPinGallery && shouldShowGallery && (
|
||||
<>
|
||||
<ResizeHandle />
|
||||
<Panel
|
||||
ref={galleryPanelRef}
|
||||
onResize={handleResizeGallery}
|
||||
id="gallery"
|
||||
order={3}
|
||||
defaultSize={
|
||||
galleryMinSizePct > DEFAULT_GALLERY_PCT
|
||||
? galleryMinSizePct
|
||||
: DEFAULT_GALLERY_PCT
|
||||
}
|
||||
minSize={galleryMinSizePct}
|
||||
maxSize={50}
|
||||
>
|
||||
<ImageGalleryContent />
|
||||
</Panel>
|
||||
</>
|
||||
)}
|
||||
{shouldPinGallery &&
|
||||
shouldShowGallery &&
|
||||
!NO_GALLERY_TABS.includes(activeTabName) && (
|
||||
<>
|
||||
<ResizeHandle />
|
||||
<Panel
|
||||
ref={galleryPanelRef}
|
||||
onResize={handleResizeGallery}
|
||||
id="gallery"
|
||||
order={3}
|
||||
defaultSize={
|
||||
galleryMinSizePct > DEFAULT_GALLERY_PCT
|
||||
? galleryMinSizePct
|
||||
: DEFAULT_GALLERY_PCT
|
||||
}
|
||||
minSize={galleryMinSizePct}
|
||||
maxSize={50}
|
||||
>
|
||||
<ImageGalleryContent />
|
||||
</Panel>
|
||||
</>
|
||||
)}
|
||||
</PanelGroup>
|
||||
</Tabs>
|
||||
);
|
||||
|
@ -1,20 +1,21 @@
|
||||
import { memo } from 'react';
|
||||
import { Box, Flex, useDisclosure } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations';
|
||||
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
|
||||
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
|
||||
import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth';
|
||||
import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight';
|
||||
import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength';
|
||||
import ImageToImageFit from 'features/parameters/components/Parameters/ImageToImage/ImageToImageFit';
|
||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||
import ParamSchedulerAndModel from 'features/parameters/components/Parameters/Core/ParamSchedulerAndModel';
|
||||
import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull';
|
||||
import IAICollapse from 'common/components/IAICollapse';
|
||||
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
|
||||
import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight';
|
||||
import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations';
|
||||
import ParamModelandVAE from 'features/parameters/components/Parameters/Core/ParamModelandVAE';
|
||||
import ParamScheduler from 'features/parameters/components/Parameters/Core/ParamScheduler';
|
||||
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
|
||||
import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth';
|
||||
import ImageToImageFit from 'features/parameters/components/Parameters/ImageToImage/ImageToImageFit';
|
||||
import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength';
|
||||
import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull';
|
||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo } from 'react';
|
||||
|
||||
const selector = createSelector(
|
||||
[uiSelector, generationSelector],
|
||||
@ -41,7 +42,7 @@ const ImageToImageTabCoreParameters = () => {
|
||||
>
|
||||
{shouldUseSliders ? (
|
||||
<>
|
||||
<ParamSchedulerAndModel />
|
||||
<ParamModelandVAE />
|
||||
<Box pt={2}>
|
||||
<ParamSeedFull />
|
||||
</Box>
|
||||
@ -58,7 +59,8 @@ const ImageToImageTabCoreParameters = () => {
|
||||
<ParamSteps />
|
||||
<ParamCFGScale />
|
||||
</Flex>
|
||||
<ParamSchedulerAndModel />
|
||||
<ParamModelandVAE />
|
||||
<ParamScheduler />
|
||||
<Box pt={2}>
|
||||
<ParamSeedFull />
|
||||
</Box>
|
||||
|
@ -0,0 +1,81 @@
|
||||
import { Tab, TabList, TabPanel, TabPanels, Tabs } from '@chakra-ui/react';
|
||||
import i18n from 'i18n';
|
||||
import { ReactNode, memo } from 'react';
|
||||
import AddModelsPanel from './subpanels/AddModelsPanel';
|
||||
import MergeModelsPanel from './subpanels/MergeModelsPanel';
|
||||
import ModelManagerPanel from './subpanels/ModelManagerPanel';
|
||||
|
||||
type ModelManagerTabName = 'modelManager' | 'addModels' | 'mergeModels';
|
||||
|
||||
type ModelManagerTabInfo = {
|
||||
id: ModelManagerTabName;
|
||||
label: string;
|
||||
content: ReactNode;
|
||||
};
|
||||
|
||||
const modelManagerTabs: ModelManagerTabInfo[] = [
|
||||
{
|
||||
id: 'modelManager',
|
||||
label: i18n.t('modelManager.modelManager'),
|
||||
content: <ModelManagerPanel />,
|
||||
},
|
||||
{
|
||||
id: 'addModels',
|
||||
label: i18n.t('modelManager.addModel'),
|
||||
content: <AddModelsPanel />,
|
||||
},
|
||||
{
|
||||
id: 'mergeModels',
|
||||
label: i18n.t('modelManager.mergeModels'),
|
||||
content: <MergeModelsPanel />,
|
||||
},
|
||||
];
|
||||
|
||||
const renderTabsList = () => {
|
||||
const modelManagerTabListsToRender: ReactNode[] = [];
|
||||
modelManagerTabs.forEach((modelManagerTab) => {
|
||||
modelManagerTabListsToRender.push(
|
||||
<Tab key={modelManagerTab.id}>{modelManagerTab.label}</Tab>
|
||||
);
|
||||
});
|
||||
|
||||
return (
|
||||
<TabList
|
||||
sx={{
|
||||
w: '100%',
|
||||
color: 'base.200',
|
||||
flexDirection: 'row',
|
||||
borderBottomWidth: 2,
|
||||
borderColor: 'accent.700',
|
||||
}}
|
||||
>
|
||||
{modelManagerTabListsToRender}
|
||||
</TabList>
|
||||
);
|
||||
};
|
||||
|
||||
const renderTabPanels = () => {
|
||||
const modelManagerTabPanelsToRender: ReactNode[] = [];
|
||||
modelManagerTabs.forEach((modelManagerTab) => {
|
||||
modelManagerTabPanelsToRender.push(
|
||||
<TabPanel key={modelManagerTab.id}>{modelManagerTab.content}</TabPanel>
|
||||
);
|
||||
});
|
||||
|
||||
return <TabPanels sx={{ p: 2 }}>{modelManagerTabPanelsToRender}</TabPanels>;
|
||||
};
|
||||
|
||||
const ModelManagerTab = () => {
|
||||
return (
|
||||
<Tabs
|
||||
isLazy
|
||||
variant="invokeAI"
|
||||
sx={{ w: 'full', h: 'full', p: 2, gap: 4, flexDirection: 'column' }}
|
||||
>
|
||||
{renderTabsList()}
|
||||
{renderTabPanels()}
|
||||
</Tabs>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ModelManagerTab);
|
@ -0,0 +1,55 @@
|
||||
import { Divider, Flex } from '@chakra-ui/react';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import AddCheckpointModel from './AddModelsPanel/AddCheckpointModel';
|
||||
import AddDiffusersModel from './AddModelsPanel/AddDiffusersModel';
|
||||
|
||||
export default function AddModelsPanel() {
|
||||
const addNewModelUIOption = useAppSelector(
|
||||
(state: RootState) => state.ui.addNewModelUIOption
|
||||
);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Flex flexDirection="column" gap={4}>
|
||||
<Flex columnGap={4}>
|
||||
<IAIButton
|
||||
onClick={() => dispatch(setAddNewModelUIOption('ckpt'))}
|
||||
sx={{
|
||||
backgroundColor:
|
||||
addNewModelUIOption == 'ckpt' ? 'accent.700' : 'base.700',
|
||||
'&:hover': {
|
||||
backgroundColor:
|
||||
addNewModelUIOption == 'ckpt' ? 'accent.700' : 'base.600',
|
||||
},
|
||||
}}
|
||||
>
|
||||
{t('modelManager.addCheckpointModel')}
|
||||
</IAIButton>
|
||||
<IAIButton
|
||||
onClick={() => dispatch(setAddNewModelUIOption('diffusers'))}
|
||||
sx={{
|
||||
backgroundColor:
|
||||
addNewModelUIOption == 'diffusers' ? 'accent.700' : 'base.700',
|
||||
'&:hover': {
|
||||
backgroundColor:
|
||||
addNewModelUIOption == 'diffusers' ? 'accent.700' : 'base.600',
|
||||
},
|
||||
}}
|
||||
>
|
||||
{t('modelManager.addDiffuserModel')}
|
||||
</IAIButton>
|
||||
</Flex>
|
||||
|
||||
<Divider />
|
||||
|
||||
{addNewModelUIOption == 'ckpt' && <AddCheckpointModel />}
|
||||
{addNewModelUIOption == 'diffusers' && <AddDiffusersModel />}
|
||||
</Flex>
|
||||
);
|
||||
}
|
@ -10,13 +10,11 @@ import {
|
||||
} from '@chakra-ui/react';
|
||||
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
import IAINumberInput from 'common/components/IAINumberInput';
|
||||
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
||||
import React from 'react';
|
||||
|
||||
import SearchModels from './SearchModels';
|
||||
|
||||
// import { addNewModel } from 'app/socketio/actions';
|
||||
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
@ -24,12 +22,13 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { Field, Formik } from 'formik';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { InvokeModelConfigProps } from 'app/types/invokeai';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
|
||||
import type { FieldInputProps, FormikProps } from 'formik';
|
||||
import type { InvokeModelConfigProps } from 'app/types/invokeai';
|
||||
import IAIForm from 'common/components/IAIForm';
|
||||
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
|
||||
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
|
||||
import type { FieldInputProps, FormikProps } from 'formik';
|
||||
import SearchModels from './SearchModels';
|
||||
|
||||
const MIN_MODEL_SIZE = 64;
|
||||
const MAX_MODEL_SIZE = 2048;
|
@ -66,7 +66,7 @@ export default function AddDiffusersModel() {
|
||||
};
|
||||
|
||||
return (
|
||||
<Flex>
|
||||
<Flex overflow="scroll" maxHeight={window.innerHeight - 270}>
|
||||
<Formik
|
||||
initialValues={addModelFormValues}
|
||||
onSubmit={addModelFormSubmitHandler}
|
@ -0,0 +1,260 @@
|
||||
import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
import IAISelect from 'common/components/IAISelect';
|
||||
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { pickBy } from 'lodash-es';
|
||||
import { useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useListModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
export default function MergeModelsPanel() {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { data } = useListModelsQuery({
|
||||
model_type: 'main',
|
||||
});
|
||||
|
||||
const diffusersModels = pickBy(
|
||||
data?.entities,
|
||||
(value, _) => value?.model_format === 'diffusers'
|
||||
);
|
||||
|
||||
const [modelOne, setModelOne] = useState<string>(
|
||||
Object.keys(diffusersModels)[0]
|
||||
);
|
||||
const [modelTwo, setModelTwo] = useState<string>(
|
||||
Object.keys(diffusersModels)[1]
|
||||
);
|
||||
const [modelThree, setModelThree] = useState<string>('none');
|
||||
|
||||
const [mergedModelName, setMergedModelName] = useState<string>('');
|
||||
const [modelMergeAlpha, setModelMergeAlpha] = useState<number>(0.5);
|
||||
|
||||
const [modelMergeInterp, setModelMergeInterp] = useState<
|
||||
'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference'
|
||||
>('weighted_sum');
|
||||
|
||||
const [modelMergeSaveLocType, setModelMergeSaveLocType] = useState<
|
||||
'root' | 'custom'
|
||||
>('root');
|
||||
|
||||
const [modelMergeCustomSaveLoc, setModelMergeCustomSaveLoc] =
|
||||
useState<string>('');
|
||||
|
||||
const [modelMergeForce, setModelMergeForce] = useState<boolean>(false);
|
||||
|
||||
const modelOneList = Object.keys(diffusersModels).filter(
|
||||
(model) => model !== modelTwo && model !== modelThree
|
||||
);
|
||||
|
||||
const modelTwoList = Object.keys(diffusersModels).filter(
|
||||
(model) => model !== modelOne && model !== modelThree
|
||||
);
|
||||
|
||||
const modelThreeList = [
|
||||
{ key: t('modelManager.none'), value: 'none' },
|
||||
...Object.keys(diffusersModels)
|
||||
.filter((model) => model !== modelOne && model !== modelTwo)
|
||||
.map((model) => ({ key: model, value: model })),
|
||||
];
|
||||
|
||||
const isProcessing = useAppSelector(
|
||||
(state: RootState) => state.system.isProcessing
|
||||
);
|
||||
|
||||
const mergeModelsHandler = () => {
|
||||
let modelsToMerge: string[] = [modelOne, modelTwo, modelThree];
|
||||
modelsToMerge = modelsToMerge.filter((model) => model !== 'none');
|
||||
|
||||
const mergeModelsInfo: InvokeAI.InvokeModelMergingProps = {
|
||||
models_to_merge: modelsToMerge,
|
||||
merged_model_name:
|
||||
mergedModelName !== '' ? mergedModelName : modelsToMerge.join('-'),
|
||||
alpha: modelMergeAlpha,
|
||||
interp: modelMergeInterp,
|
||||
model_merge_save_path:
|
||||
modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc,
|
||||
force: modelMergeForce,
|
||||
};
|
||||
|
||||
dispatch(mergeDiffusersModels(mergeModelsInfo));
|
||||
};
|
||||
|
||||
return (
|
||||
<Flex flexDirection="column" rowGap={4}>
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
rowGap: 1,
|
||||
bg: 'base.900',
|
||||
}}
|
||||
>
|
||||
<Text>{t('modelManager.modelMergeHeaderHelp1')}</Text>
|
||||
<Text fontSize="sm" variant="subtext">
|
||||
{t('modelManager.modelMergeHeaderHelp2')}
|
||||
</Text>
|
||||
</Flex>
|
||||
<Flex columnGap={4}>
|
||||
<IAISelect
|
||||
label={t('modelManager.modelOne')}
|
||||
validValues={modelOneList}
|
||||
onChange={(e) => setModelOne(e.target.value)}
|
||||
/>
|
||||
<IAISelect
|
||||
label={t('modelManager.modelTwo')}
|
||||
validValues={modelTwoList}
|
||||
onChange={(e) => setModelTwo(e.target.value)}
|
||||
/>
|
||||
<IAISelect
|
||||
label={t('modelManager.modelThree')}
|
||||
validValues={modelThreeList}
|
||||
onChange={(e) => {
|
||||
if (e.target.value !== 'none') {
|
||||
setModelThree(e.target.value);
|
||||
setModelMergeInterp('add_difference');
|
||||
} else {
|
||||
setModelThree('none');
|
||||
setModelMergeInterp('weighted_sum');
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</Flex>
|
||||
|
||||
<IAIInput
|
||||
label={t('modelManager.mergedModelName')}
|
||||
value={mergedModelName}
|
||||
onChange={(e) => setMergedModelName(e.target.value)}
|
||||
/>
|
||||
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
padding: 4,
|
||||
borderRadius: 'base',
|
||||
gap: 4,
|
||||
bg: 'base.900',
|
||||
}}
|
||||
>
|
||||
<IAISlider
|
||||
label={t('modelManager.alpha')}
|
||||
min={0.01}
|
||||
max={0.99}
|
||||
step={0.01}
|
||||
value={modelMergeAlpha}
|
||||
onChange={(v) => setModelMergeAlpha(v)}
|
||||
withInput
|
||||
withReset
|
||||
handleReset={() => setModelMergeAlpha(0.5)}
|
||||
withSliderMarks
|
||||
/>
|
||||
<Text variant="subtext" fontSize="sm">
|
||||
{t('modelManager.modelMergeAlphaHelp')}
|
||||
</Text>
|
||||
</Flex>
|
||||
|
||||
<Flex
|
||||
sx={{
|
||||
padding: 4,
|
||||
borderRadius: 'base',
|
||||
gap: 4,
|
||||
bg: 'base.900',
|
||||
}}
|
||||
>
|
||||
<Text fontWeight={500} fontSize="sm" variant="subtext">
|
||||
{t('modelManager.interpolationType')}
|
||||
</Text>
|
||||
<RadioGroup
|
||||
value={modelMergeInterp}
|
||||
onChange={(
|
||||
v: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference'
|
||||
) => setModelMergeInterp(v)}
|
||||
>
|
||||
<Flex columnGap={4}>
|
||||
{modelThree === 'none' ? (
|
||||
<>
|
||||
<Radio value="weighted_sum">
|
||||
<Text fontSize="sm">{t('modelManager.weightedSum')}</Text>
|
||||
</Radio>
|
||||
<Radio value="sigmoid">
|
||||
<Text fontSize="sm">{t('modelManager.sigmoid')}</Text>
|
||||
</Radio>
|
||||
<Radio value="inv_sigmoid">
|
||||
<Text fontSize="sm">{t('modelManager.inverseSigmoid')}</Text>
|
||||
</Radio>
|
||||
</>
|
||||
) : (
|
||||
<Radio value="add_difference">
|
||||
<Tooltip
|
||||
label={t('modelManager.modelMergeInterpAddDifferenceHelp')}
|
||||
>
|
||||
<Text fontSize="sm">{t('modelManager.addDifference')}</Text>
|
||||
</Tooltip>
|
||||
</Radio>
|
||||
)}
|
||||
</Flex>
|
||||
</RadioGroup>
|
||||
</Flex>
|
||||
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
padding: 4,
|
||||
borderRadius: 'base',
|
||||
gap: 4,
|
||||
bg: 'base.900',
|
||||
}}
|
||||
>
|
||||
<Flex columnGap={4}>
|
||||
<Text fontWeight="500" fontSize="sm" variant="subtext">
|
||||
{t('modelManager.mergedModelSaveLocation')}
|
||||
</Text>
|
||||
<RadioGroup
|
||||
value={modelMergeSaveLocType}
|
||||
onChange={(v: 'root' | 'custom') => setModelMergeSaveLocType(v)}
|
||||
>
|
||||
<Flex columnGap={4}>
|
||||
<Radio value="root">
|
||||
<Text fontSize="sm">{t('modelManager.invokeAIFolder')}</Text>
|
||||
</Radio>
|
||||
|
||||
<Radio value="custom">
|
||||
<Text fontSize="sm">{t('modelManager.custom')}</Text>
|
||||
</Radio>
|
||||
</Flex>
|
||||
</RadioGroup>
|
||||
</Flex>
|
||||
|
||||
{modelMergeSaveLocType === 'custom' && (
|
||||
<IAIInput
|
||||
label={t('modelManager.mergedModelCustomSaveLocation')}
|
||||
value={modelMergeCustomSaveLoc}
|
||||
onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
|
||||
<IAISimpleCheckbox
|
||||
label={t('modelManager.ignoreMismatch')}
|
||||
isChecked={modelMergeForce}
|
||||
onChange={(e) => setModelMergeForce(e.target.checked)}
|
||||
fontWeight="500"
|
||||
/>
|
||||
|
||||
<IAIButton
|
||||
onClick={mergeModelsHandler}
|
||||
isLoading={isProcessing}
|
||||
isDisabled={
|
||||
modelMergeSaveLocType === 'custom' && modelMergeCustomSaveLoc === ''
|
||||
}
|
||||
>
|
||||
{t('modelManager.merge')}
|
||||
</IAIButton>
|
||||
</Flex>
|
||||
);
|
||||
}
|
@ -0,0 +1,46 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
|
||||
import { useListModelsQuery } from 'services/api/endpoints/models';
|
||||
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
|
||||
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
|
||||
import ModelList from './ModelManagerPanel/ModelList';
|
||||
|
||||
export default function ModelManagerPanel() {
|
||||
const { data: mainModels } = useListModelsQuery({
|
||||
model_type: 'main',
|
||||
});
|
||||
|
||||
const openModel = useAppSelector(
|
||||
(state: RootState) => state.system.openModel
|
||||
);
|
||||
|
||||
const renderModelEditTabs = () => {
|
||||
if (!openModel || !mainModels) return;
|
||||
|
||||
if (mainModels['entities'][openModel]['model_format'] === 'diffusers') {
|
||||
return (
|
||||
<DiffusersModelEdit
|
||||
modelToEdit={openModel}
|
||||
retrievedModel={mainModels['entities'][openModel]}
|
||||
key={openModel}
|
||||
/>
|
||||
);
|
||||
} else {
|
||||
return (
|
||||
<CheckpointModelEdit
|
||||
modelToEdit={openModel}
|
||||
retrievedModel={mainModels['entities'][openModel]}
|
||||
key={openModel}
|
||||
/>
|
||||
);
|
||||
}
|
||||
};
|
||||
return (
|
||||
<Flex width="100%" columnGap={8}>
|
||||
<ModelList />
|
||||
{renderModelEditTabs()}
|
||||
</Flex>
|
||||
);
|
||||
}
|
@ -0,0 +1,141 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
|
||||
import { Divider, Flex, Text } from '@chakra-ui/react';
|
||||
|
||||
// import { addNewModel } from 'app/socketio/actions';
|
||||
import { useForm } from '@mantine/form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { RootState } from 'app/store/store';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
|
||||
import { S } from 'services/api/types';
|
||||
import ModelConvert from './ModelConvert';
|
||||
|
||||
const baseModelSelectData = [
|
||||
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
|
||||
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
|
||||
];
|
||||
|
||||
const variantSelectData = [
|
||||
{ value: 'normal', label: 'Normal' },
|
||||
{ value: 'inpaint', label: 'Inpaint' },
|
||||
{ value: 'depth', label: 'Depth' },
|
||||
];
|
||||
|
||||
export type CheckpointModel =
|
||||
| S<'StableDiffusion1ModelCheckpointConfig'>
|
||||
| S<'StableDiffusion2ModelCheckpointConfig'>;
|
||||
|
||||
type CheckpointModelEditProps = {
|
||||
modelToEdit: string;
|
||||
retrievedModel: CheckpointModel;
|
||||
};
|
||||
|
||||
export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
|
||||
const isProcessing = useAppSelector(
|
||||
(state: RootState) => state.system.isProcessing
|
||||
);
|
||||
|
||||
const { modelToEdit, retrievedModel } = props;
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const checkpointEditForm = useForm({
|
||||
initialValues: {
|
||||
name: retrievedModel.name,
|
||||
base_model: retrievedModel.base_model,
|
||||
type: 'main',
|
||||
path: retrievedModel.path,
|
||||
description: retrievedModel.description,
|
||||
model_format: 'checkpoint',
|
||||
vae: retrievedModel.vae,
|
||||
config: retrievedModel.config,
|
||||
variant: retrievedModel.variant,
|
||||
},
|
||||
});
|
||||
|
||||
const editModelFormSubmitHandler = (values) => {
|
||||
console.log(values);
|
||||
};
|
||||
|
||||
return modelToEdit ? (
|
||||
<Flex flexDirection="column" rowGap={4} width="100%">
|
||||
<Flex justifyContent="space-between" alignItems="center">
|
||||
<Flex flexDirection="column">
|
||||
<Text fontSize="lg" fontWeight="bold">
|
||||
{retrievedModel.name}
|
||||
</Text>
|
||||
<Text fontSize="sm" color="base.400">
|
||||
{MODEL_TYPE_MAP[retrievedModel.base_model]} Model
|
||||
</Text>
|
||||
</Flex>
|
||||
<ModelConvert model={retrievedModel} />
|
||||
</Flex>
|
||||
<Divider />
|
||||
|
||||
<Flex
|
||||
flexDirection="column"
|
||||
maxHeight={window.innerHeight - 270}
|
||||
overflowY="scroll"
|
||||
>
|
||||
<form
|
||||
onSubmit={checkpointEditForm.onSubmit((values) =>
|
||||
editModelFormSubmitHandler(values)
|
||||
)}
|
||||
>
|
||||
<Flex flexDirection="column" overflowY="scroll" gap={4}>
|
||||
<IAIInput
|
||||
label={t('modelManager.name')}
|
||||
{...checkpointEditForm.getInputProps('name')}
|
||||
/>
|
||||
<IAIInput
|
||||
label={t('modelManager.description')}
|
||||
{...checkpointEditForm.getInputProps('description')}
|
||||
/>
|
||||
<IAIMantineSelect
|
||||
label={t('modelManager.baseModel')}
|
||||
data={baseModelSelectData}
|
||||
{...checkpointEditForm.getInputProps('base_model')}
|
||||
/>
|
||||
<IAIMantineSelect
|
||||
label={t('modelManager.variant')}
|
||||
data={variantSelectData}
|
||||
{...checkpointEditForm.getInputProps('variant')}
|
||||
/>
|
||||
<IAIInput
|
||||
label={t('modelManager.modelLocation')}
|
||||
{...checkpointEditForm.getInputProps('path')}
|
||||
/>
|
||||
<IAIInput
|
||||
label={t('modelManager.vaeLocation')}
|
||||
{...checkpointEditForm.getInputProps('vae')}
|
||||
/>
|
||||
<IAIInput
|
||||
label={t('modelManager.config')}
|
||||
{...checkpointEditForm.getInputProps('config')}
|
||||
/>
|
||||
<IAIButton disabled={isProcessing} type="submit">
|
||||
{t('modelManager.updateModel')}
|
||||
</IAIButton>
|
||||
</Flex>
|
||||
</form>
|
||||
</Flex>
|
||||
</Flex>
|
||||
) : (
|
||||
<Flex
|
||||
sx={{
|
||||
width: '100%',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
borderRadius: 'base',
|
||||
bg: 'base.900',
|
||||
}}
|
||||
>
|
||||
<Text fontWeight={500}>Pick A Model To Edit</Text>
|
||||
</Flex>
|
||||
);
|
||||
}
|
@ -0,0 +1,125 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
|
||||
import { Divider, Flex, Text } from '@chakra-ui/react';
|
||||
|
||||
// import { addNewModel } from 'app/socketio/actions';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { useForm } from '@mantine/form';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
|
||||
import { S } from 'services/api/types';
|
||||
|
||||
type DiffusersModel =
|
||||
| S<'StableDiffusion1ModelDiffusersConfig'>
|
||||
| S<'StableDiffusion2ModelDiffusersConfig'>;
|
||||
|
||||
type DiffusersModelEditProps = {
|
||||
modelToEdit: string;
|
||||
retrievedModel: DiffusersModel;
|
||||
};
|
||||
|
||||
const baseModelSelectData = [
|
||||
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
|
||||
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
|
||||
];
|
||||
|
||||
const variantSelectData = [
|
||||
{ value: 'normal', label: 'Normal' },
|
||||
{ value: 'inpaint', label: 'Inpaint' },
|
||||
{ value: 'depth', label: 'Depth' },
|
||||
];
|
||||
|
||||
export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
|
||||
const isProcessing = useAppSelector(
|
||||
(state: RootState) => state.system.isProcessing
|
||||
);
|
||||
const { retrievedModel, modelToEdit } = props;
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const diffusersEditForm = useForm({
|
||||
initialValues: {
|
||||
name: retrievedModel.name,
|
||||
base_model: retrievedModel.base_model,
|
||||
type: 'main',
|
||||
path: retrievedModel.path,
|
||||
description: retrievedModel.description,
|
||||
model_format: 'diffusers',
|
||||
vae: retrievedModel.vae,
|
||||
variant: retrievedModel.variant,
|
||||
},
|
||||
});
|
||||
|
||||
const editModelFormSubmitHandler = (values) => {
|
||||
console.log(values);
|
||||
};
|
||||
|
||||
return modelToEdit ? (
|
||||
<Flex flexDirection="column" rowGap={4} width="100%">
|
||||
<Flex flexDirection="column">
|
||||
<Text fontSize="lg" fontWeight="bold">
|
||||
{retrievedModel.name}
|
||||
</Text>
|
||||
<Text fontSize="sm" color="base.400">
|
||||
{MODEL_TYPE_MAP[retrievedModel.base_model]} Model
|
||||
</Text>
|
||||
</Flex>
|
||||
<Divider />
|
||||
|
||||
<form
|
||||
onSubmit={diffusersEditForm.onSubmit((values) =>
|
||||
editModelFormSubmitHandler(values)
|
||||
)}
|
||||
>
|
||||
<Flex flexDirection="column" overflowY="scroll" gap={4}>
|
||||
<IAIInput
|
||||
label={t('modelManager.name')}
|
||||
{...diffusersEditForm.getInputProps('name')}
|
||||
/>
|
||||
<IAIInput
|
||||
label={t('modelManager.description')}
|
||||
{...diffusersEditForm.getInputProps('description')}
|
||||
/>
|
||||
<IAIMantineSelect
|
||||
label={t('modelManager.baseModel')}
|
||||
data={baseModelSelectData}
|
||||
{...diffusersEditForm.getInputProps('base_model')}
|
||||
/>
|
||||
<IAIMantineSelect
|
||||
label={t('modelManager.variant')}
|
||||
data={variantSelectData}
|
||||
{...diffusersEditForm.getInputProps('variant')}
|
||||
/>
|
||||
<IAIInput
|
||||
label={t('modelManager.modelLocation')}
|
||||
{...diffusersEditForm.getInputProps('path')}
|
||||
/>
|
||||
<IAIInput
|
||||
label={t('modelManager.vaeLocation')}
|
||||
{...diffusersEditForm.getInputProps('vae')}
|
||||
/>
|
||||
<IAIButton disabled={isProcessing} type="submit">
|
||||
{t('modelManager.updateModel')}
|
||||
</IAIButton>
|
||||
</Flex>
|
||||
</form>
|
||||
</Flex>
|
||||
) : (
|
||||
<Flex
|
||||
sx={{
|
||||
width: '100%',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
borderRadius: 'base',
|
||||
bg: 'base.900',
|
||||
}}
|
||||
>
|
||||
<Text fontWeight={'500'}>Pick A Model To Edit</Text>
|
||||
</Flex>
|
||||
);
|
||||
}
|
@ -4,42 +4,28 @@ import {
|
||||
Radio,
|
||||
RadioGroup,
|
||||
Text,
|
||||
UnorderedList,
|
||||
Tooltip,
|
||||
UnorderedList,
|
||||
} from '@chakra-ui/react';
|
||||
// import { convertToDiffusers } from 'app/socketio/actions';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIAlertDialog from 'common/components/IAIAlertDialog';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
import { useState, useEffect } from 'react';
|
||||
import { useEffect, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { CheckpointModel } from './CheckpointModelEdit';
|
||||
|
||||
interface ModelConvertProps {
|
||||
model: string;
|
||||
model: CheckpointModel;
|
||||
}
|
||||
|
||||
export default function ModelConvert(props: ModelConvertProps) {
|
||||
const { model } = props;
|
||||
|
||||
const model_list = useAppSelector(
|
||||
(state: RootState) => state.system.model_list
|
||||
);
|
||||
|
||||
const retrievedModel = model_list[model];
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const isProcessing = useAppSelector(
|
||||
(state: RootState) => state.system.isProcessing
|
||||
);
|
||||
|
||||
const isConnected = useAppSelector(
|
||||
(state: RootState) => state.system.isConnected
|
||||
);
|
||||
|
||||
const [saveLocation, setSaveLocation] = useState<string>('same');
|
||||
const [customSaveLocation, setCustomSaveLocation] = useState<string>('');
|
||||
|
||||
@ -65,7 +51,7 @@ export default function ModelConvert(props: ModelConvertProps) {
|
||||
|
||||
return (
|
||||
<IAIAlertDialog
|
||||
title={`${t('modelManager.convert')} ${model}`}
|
||||
title={`${t('modelManager.convert')} ${model.name}`}
|
||||
acceptCallback={modelConvertHandler}
|
||||
cancelCallback={modelConvertCancelHandler}
|
||||
acceptButtonText={`${t('modelManager.convert')}`}
|
||||
@ -73,11 +59,7 @@ export default function ModelConvert(props: ModelConvertProps) {
|
||||
<IAIButton
|
||||
size={'sm'}
|
||||
aria-label={t('modelManager.convertToDiffusers')}
|
||||
isDisabled={
|
||||
retrievedModel.status === 'active' || isProcessing || !isConnected
|
||||
}
|
||||
className=" modal-close-btn"
|
||||
marginInlineEnd={8}
|
||||
>
|
||||
🧨 {t('modelManager.convertToDiffusers')}
|
||||
</IAIButton>
|
@ -1,36 +1,14 @@
|
||||
import { Box, Flex, Heading, Spacer, Spinner, Text } from '@chakra-ui/react';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
import { Box, Flex, Spinner, Text } from '@chakra-ui/react';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
|
||||
import AddModel from './AddModel';
|
||||
import ModelListItem from './ModelListItem';
|
||||
import MergeModels from './MergeModels';
|
||||
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import type { SystemState } from 'features/system/store/systemSlice';
|
||||
import { isEqual, map } from 'lodash-es';
|
||||
|
||||
import React, { useMemo, useState, useTransition } from 'react';
|
||||
import type { ChangeEvent, ReactNode } from 'react';
|
||||
|
||||
const modelListSelector = createSelector(
|
||||
systemSelector,
|
||||
(system: SystemState) => {
|
||||
const models = map(system.model_list, (model, key) => {
|
||||
return { name: key, ...model };
|
||||
});
|
||||
return models;
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
);
|
||||
import React, { useMemo, useState, useTransition } from 'react';
|
||||
import { useListModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
function ModelFilterButton({
|
||||
label,
|
||||
@ -58,7 +36,9 @@ function ModelFilterButton({
|
||||
}
|
||||
|
||||
const ModelList = () => {
|
||||
const models = useAppSelector(modelListSelector);
|
||||
const { data: mainModels } = useListModelsQuery({
|
||||
model_type: 'main',
|
||||
});
|
||||
|
||||
const [renderModelList, setRenderModelList] = React.useState<boolean>(false);
|
||||
|
||||
@ -90,43 +70,49 @@ const ModelList = () => {
|
||||
const filteredModelListItemsToRender: ReactNode[] = [];
|
||||
const localFilteredModelListItemsToRender: ReactNode[] = [];
|
||||
|
||||
models.forEach((model, i) => {
|
||||
if (model.name.toLowerCase().includes(searchText.toLowerCase())) {
|
||||
if (!mainModels) return;
|
||||
|
||||
const modelList = mainModels.entities;
|
||||
|
||||
Object.keys(modelList).forEach((model, i) => {
|
||||
if (
|
||||
modelList[model].name.toLowerCase().includes(searchText.toLowerCase())
|
||||
) {
|
||||
filteredModelListItemsToRender.push(
|
||||
<ModelListItem
|
||||
key={i}
|
||||
name={model.name}
|
||||
status={model.status}
|
||||
description={model.description}
|
||||
modelKey={model}
|
||||
name={modelList[model].name}
|
||||
description={modelList[model].description}
|
||||
/>
|
||||
);
|
||||
if (model.format === isSelectedFilter) {
|
||||
if (modelList[model]?.model_format === isSelectedFilter) {
|
||||
localFilteredModelListItemsToRender.push(
|
||||
<ModelListItem
|
||||
key={i}
|
||||
name={model.name}
|
||||
status={model.status}
|
||||
description={model.description}
|
||||
modelKey={model}
|
||||
name={modelList[model].name}
|
||||
description={modelList[model].description}
|
||||
/>
|
||||
);
|
||||
}
|
||||
}
|
||||
if (model.format !== 'diffusers') {
|
||||
if (modelList[model]?.model_format !== 'diffusers') {
|
||||
ckptModelListItemsToRender.push(
|
||||
<ModelListItem
|
||||
key={i}
|
||||
name={model.name}
|
||||
status={model.status}
|
||||
description={model.description}
|
||||
modelKey={model}
|
||||
name={modelList[model].name}
|
||||
description={modelList[model].description}
|
||||
/>
|
||||
);
|
||||
} else {
|
||||
diffusersModelListItemsToRender.push(
|
||||
<ModelListItem
|
||||
key={i}
|
||||
name={model.name}
|
||||
status={model.status}
|
||||
description={model.description}
|
||||
modelKey={model}
|
||||
name={modelList[model].name}
|
||||
description={modelList[model].description}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@ -142,6 +128,23 @@ const ModelList = () => {
|
||||
<Flex flexDirection="column" rowGap={6}>
|
||||
{isSelectedFilter === 'all' && (
|
||||
<>
|
||||
<Box>
|
||||
<Text
|
||||
sx={{
|
||||
fontWeight: '500',
|
||||
py: 2,
|
||||
px: 4,
|
||||
mb: 4,
|
||||
borderRadius: 'base',
|
||||
width: 'max-content',
|
||||
fontSize: 'sm',
|
||||
bg: 'base.750',
|
||||
}}
|
||||
>
|
||||
{t('modelManager.diffusersModels')}
|
||||
</Text>
|
||||
{diffusersModelListItemsToRender}
|
||||
</Box>
|
||||
<Box>
|
||||
<Text
|
||||
sx={{
|
||||
@ -160,50 +163,26 @@ const ModelList = () => {
|
||||
</Text>
|
||||
{ckptModelListItemsToRender}
|
||||
</Box>
|
||||
<Box>
|
||||
<Text
|
||||
sx={{
|
||||
fontWeight: '500',
|
||||
py: 2,
|
||||
px: 4,
|
||||
mb: 4,
|
||||
borderRadius: 'base',
|
||||
width: 'max-content',
|
||||
fontSize: 'sm',
|
||||
bg: 'base.750',
|
||||
}}
|
||||
>
|
||||
{t('modelManager.diffusersModels')}
|
||||
</Text>
|
||||
{diffusersModelListItemsToRender}
|
||||
</Box>
|
||||
</>
|
||||
)}
|
||||
|
||||
{isSelectedFilter === 'ckpt' && (
|
||||
<Flex flexDirection="column" marginTop={4}>
|
||||
{ckptModelListItemsToRender}
|
||||
</Flex>
|
||||
)}
|
||||
|
||||
{isSelectedFilter === 'diffusers' && (
|
||||
<Flex flexDirection="column" marginTop={4}>
|
||||
{diffusersModelListItemsToRender}
|
||||
</Flex>
|
||||
)}
|
||||
|
||||
{isSelectedFilter === 'ckpt' && (
|
||||
<Flex flexDirection="column" marginTop={4}>
|
||||
{ckptModelListItemsToRender}
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
}, [models, searchText, t, isSelectedFilter]);
|
||||
}, [mainModels, searchText, t, isSelectedFilter]);
|
||||
|
||||
return (
|
||||
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
|
||||
<Flex justifyContent="space-between" alignItems="center" gap={2}>
|
||||
<Heading size="md">{t('modelManager.availableModels')}</Heading>
|
||||
<Spacer />
|
||||
<AddModel />
|
||||
<MergeModels />
|
||||
</Flex>
|
||||
|
||||
<IAIInput
|
||||
onChange={handleSearchFilter}
|
||||
label={t('modelManager.search')}
|
||||
@ -211,7 +190,7 @@ const ModelList = () => {
|
||||
|
||||
<Flex
|
||||
flexDirection="column"
|
||||
gap={1}
|
||||
gap={4}
|
||||
maxHeight={window.innerHeight - 240}
|
||||
overflow="scroll"
|
||||
paddingInlineEnd={4}
|
||||
@ -222,16 +201,16 @@ const ModelList = () => {
|
||||
onClick={() => setIsSelectedFilter('all')}
|
||||
isActive={isSelectedFilter === 'all'}
|
||||
/>
|
||||
<ModelFilterButton
|
||||
label={t('modelManager.checkpointModels')}
|
||||
onClick={() => setIsSelectedFilter('ckpt')}
|
||||
isActive={isSelectedFilter === 'ckpt'}
|
||||
/>
|
||||
<ModelFilterButton
|
||||
label={t('modelManager.diffusersModels')}
|
||||
onClick={() => setIsSelectedFilter('diffusers')}
|
||||
isActive={isSelectedFilter === 'diffusers'}
|
||||
/>
|
||||
<ModelFilterButton
|
||||
label={t('modelManager.checkpointModels')}
|
||||
onClick={() => setIsSelectedFilter('ckpt')}
|
||||
isActive={isSelectedFilter === 'ckpt'}
|
||||
/>
|
||||
</Flex>
|
||||
|
||||
{renderModelList ? (
|
@ -1,6 +1,6 @@
|
||||
import { DeleteIcon, EditIcon } from '@chakra-ui/icons';
|
||||
import { Box, Button, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react';
|
||||
import { ModelStatus } from 'app/types/invokeai';
|
||||
import { Box, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react';
|
||||
|
||||
// import { deleteModel, requestModelChange } from 'app/socketio/actions';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
@ -10,9 +10,9 @@ import { setOpenModel } from 'features/system/store/systemSlice';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type ModelListItemProps = {
|
||||
modelKey: string;
|
||||
name: string;
|
||||
status: ModelStatus;
|
||||
description: string;
|
||||
description: string | undefined;
|
||||
};
|
||||
|
||||
export default function ModelListItem(props: ModelListItemProps) {
|
||||
@ -28,39 +28,24 @@ export default function ModelListItem(props: ModelListItemProps) {
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { name, status, description } = props;
|
||||
|
||||
const handleChangeModel = () => {
|
||||
dispatch(requestModelChange(name));
|
||||
};
|
||||
const { modelKey, name, description } = props;
|
||||
|
||||
const openModelHandler = () => {
|
||||
dispatch(setOpenModel(name));
|
||||
dispatch(setOpenModel(modelKey));
|
||||
};
|
||||
|
||||
const handleModelDelete = () => {
|
||||
dispatch(deleteModel(name));
|
||||
dispatch(deleteModel(modelKey));
|
||||
dispatch(setOpenModel(null));
|
||||
};
|
||||
|
||||
const statusTextColor = () => {
|
||||
switch (status) {
|
||||
case 'active':
|
||||
return 'ok.500';
|
||||
case 'cached':
|
||||
return 'warning.500';
|
||||
case 'not loaded':
|
||||
return 'inherit';
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Flex
|
||||
alignItems="center"
|
||||
p={2}
|
||||
borderRadius="base"
|
||||
sx={
|
||||
name === openModel
|
||||
modelKey === openModel
|
||||
? {
|
||||
bg: 'accent.750',
|
||||
_hover: {
|
||||
@ -81,15 +66,6 @@ export default function ModelListItem(props: ModelListItemProps) {
|
||||
</Box>
|
||||
<Spacer onClick={openModelHandler} cursor="pointer" />
|
||||
<Flex gap={2} alignItems="center">
|
||||
<Text color={statusTextColor()}>{status}</Text>
|
||||
<Button
|
||||
size="sm"
|
||||
onClick={handleChangeModel}
|
||||
isDisabled={status === 'active' || isProcessing || !isConnected}
|
||||
>
|
||||
{t('modelManager.load')}
|
||||
</Button>
|
||||
|
||||
<IAIIconButton
|
||||
icon={<EditIcon />}
|
||||
size="sm"
|
@ -1,17 +1,18 @@
|
||||
import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations';
|
||||
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
|
||||
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
|
||||
import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth';
|
||||
import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight';
|
||||
import { Box, Flex, useDisclosure } from '@chakra-ui/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { memo } from 'react';
|
||||
import ParamSchedulerAndModel from 'features/parameters/components/Parameters/Core/ParamSchedulerAndModel';
|
||||
import IAICollapse from 'common/components/IAICollapse';
|
||||
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
|
||||
import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight';
|
||||
import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations';
|
||||
import ParamModelandVAE from 'features/parameters/components/Parameters/Core/ParamModelandVAE';
|
||||
import ParamScheduler from 'features/parameters/components/Parameters/Core/ParamScheduler';
|
||||
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
|
||||
import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth';
|
||||
import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo } from 'react';
|
||||
|
||||
const selector = createSelector(
|
||||
uiSelector,
|
||||
@ -37,7 +38,7 @@ const TextToImageTabCoreParameters = () => {
|
||||
>
|
||||
{shouldUseSliders ? (
|
||||
<>
|
||||
<ParamSchedulerAndModel />
|
||||
<ParamModelandVAE />
|
||||
<Box pt={2}>
|
||||
<ParamSeedFull />
|
||||
</Box>
|
||||
@ -54,7 +55,8 @@ const TextToImageTabCoreParameters = () => {
|
||||
<ParamSteps />
|
||||
<ParamCFGScale />
|
||||
</Flex>
|
||||
<ParamSchedulerAndModel />
|
||||
<ParamModelandVAE />
|
||||
<ParamScheduler />
|
||||
<Box pt={2}>
|
||||
<ParamSeedFull />
|
||||
</Box>
|
||||
|
@ -1,18 +1,19 @@
|
||||
import { memo } from 'react';
|
||||
import { Box, Flex, useDisclosure } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations';
|
||||
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
|
||||
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
|
||||
import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength';
|
||||
import ParamSchedulerAndModel from 'features/parameters/components/Parameters/Core/ParamSchedulerAndModel';
|
||||
import ParamBoundingBoxWidth from 'features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxWidth';
|
||||
import ParamBoundingBoxHeight from 'features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxHeight';
|
||||
import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull';
|
||||
import IAICollapse from 'common/components/IAICollapse';
|
||||
import ParamBoundingBoxHeight from 'features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxHeight';
|
||||
import ParamBoundingBoxWidth from 'features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxWidth';
|
||||
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
|
||||
import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations';
|
||||
import ParamModelandVAE from 'features/parameters/components/Parameters/Core/ParamModelandVAE';
|
||||
import ParamScheduler from 'features/parameters/components/Parameters/Core/ParamScheduler';
|
||||
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
|
||||
import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength';
|
||||
import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo } from 'react';
|
||||
|
||||
const selector = createSelector(
|
||||
uiSelector,
|
||||
@ -38,7 +39,7 @@ const UnifiedCanvasCoreParameters = () => {
|
||||
>
|
||||
{shouldUseSliders ? (
|
||||
<>
|
||||
<ParamSchedulerAndModel />
|
||||
<ParamModelandVAE />
|
||||
<Box pt={2}>
|
||||
<ParamSeedFull />
|
||||
</Box>
|
||||
@ -55,7 +56,8 @@ const UnifiedCanvasCoreParameters = () => {
|
||||
<ParamSteps />
|
||||
<ParamCFGScale />
|
||||
</Flex>
|
||||
<ParamSchedulerAndModel />
|
||||
<ParamModelandVAE />
|
||||
<ParamScheduler />
|
||||
<Box pt={2}>
|
||||
<ParamSeedFull />
|
||||
</Box>
|
||||
|
@ -7,6 +7,7 @@ export const tabMap = [
|
||||
'batch',
|
||||
// 'postprocessing',
|
||||
// 'training',
|
||||
'modelManager',
|
||||
] as const;
|
||||
|
||||
export type InvokeTabName = (typeof tabMap)[number];
|
||||
|
240
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
240
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
@ -76,9 +76,16 @@ export type paths = {
|
||||
*/
|
||||
get: operations["list_models"];
|
||||
/**
|
||||
* Import Model
|
||||
* Update Model
|
||||
* @description Add Model
|
||||
*/
|
||||
post: operations["update_model"];
|
||||
};
|
||||
"/api/v1/models/import": {
|
||||
/**
|
||||
* Import Model
|
||||
* @description Add a model using its local path, repo_id, or remote URL
|
||||
*/
|
||||
post: operations["import_model"];
|
||||
};
|
||||
"/api/v1/models/{model_name}": {
|
||||
@ -227,6 +234,23 @@ export type components = {
|
||||
*/
|
||||
b?: number;
|
||||
};
|
||||
/** AddModelResult */
|
||||
AddModelResult: {
|
||||
/**
|
||||
* Name
|
||||
* @description The name of the model after import
|
||||
*/
|
||||
name: string;
|
||||
/** @description The type of model */
|
||||
model_type: components["schemas"]["ModelType"];
|
||||
/** @description The base model */
|
||||
base_model: components["schemas"]["BaseModelType"];
|
||||
/**
|
||||
* Config
|
||||
* @description The configuration of the model
|
||||
*/
|
||||
config: components["schemas"]["ModelConfigBase"];
|
||||
};
|
||||
/**
|
||||
* BaseModelType
|
||||
* @description An enumeration.
|
||||
@ -1030,7 +1054,7 @@ export type components = {
|
||||
* @description The nodes in this graph
|
||||
*/
|
||||
nodes?: {
|
||||
[key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
|
||||
[key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
|
||||
};
|
||||
/**
|
||||
* Edges
|
||||
@ -1073,7 +1097,7 @@ export type components = {
|
||||
* @description The results of node executions
|
||||
*/
|
||||
results: {
|
||||
[key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
|
||||
[key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
|
||||
};
|
||||
/**
|
||||
* Errors
|
||||
@ -1975,19 +1999,23 @@ export type components = {
|
||||
*/
|
||||
thumbnail_url: string;
|
||||
};
|
||||
/** ImportModelRequest */
|
||||
ImportModelRequest: {
|
||||
/** ImportModelResponse */
|
||||
ImportModelResponse: {
|
||||
/**
|
||||
* Name
|
||||
* @description A model path, repo_id or URL to import
|
||||
* @description The name of the imported model
|
||||
*/
|
||||
name: string;
|
||||
/**
|
||||
* Prediction Type
|
||||
* @description Prediction type for SDv2 checkpoint files
|
||||
* @enum {string}
|
||||
* Info
|
||||
* @description The model info
|
||||
*/
|
||||
prediction_type?: "epsilon" | "v_prediction" | "sample";
|
||||
info: components["schemas"]["AddModelResult"];
|
||||
/**
|
||||
* Status
|
||||
* @description The status of the API response
|
||||
*/
|
||||
status: string;
|
||||
};
|
||||
/**
|
||||
* InfillColorInvocation
|
||||
@ -2781,6 +2809,47 @@ export type components = {
|
||||
*/
|
||||
clip?: components["schemas"]["ClipField"];
|
||||
};
|
||||
/**
|
||||
* MainModelField
|
||||
* @description Main model field
|
||||
*/
|
||||
MainModelField: {
|
||||
/**
|
||||
* Model Name
|
||||
* @description Name of the model
|
||||
*/
|
||||
model_name: string;
|
||||
/** @description Base model */
|
||||
base_model: components["schemas"]["BaseModelType"];
|
||||
};
|
||||
/**
|
||||
* MainModelLoaderInvocation
|
||||
* @description Loads a main model, outputting its submodels.
|
||||
*/
|
||||
MainModelLoaderInvocation: {
|
||||
/**
|
||||
* Id
|
||||
* @description The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
/**
|
||||
* Is Intermediate
|
||||
* @description Whether or not this node is an intermediate node.
|
||||
* @default false
|
||||
*/
|
||||
is_intermediate?: boolean;
|
||||
/**
|
||||
* Type
|
||||
* @default main_model_loader
|
||||
* @enum {string}
|
||||
*/
|
||||
type?: "main_model_loader";
|
||||
/**
|
||||
* Model
|
||||
* @description The model to load
|
||||
*/
|
||||
model: components["schemas"]["MainModelField"];
|
||||
};
|
||||
/**
|
||||
* MaskFromAlphaInvocation
|
||||
* @description Extracts the alpha channel of an image as a mask.
|
||||
@ -2974,6 +3043,16 @@ export type components = {
|
||||
*/
|
||||
thr_d?: number;
|
||||
};
|
||||
/** ModelConfigBase */
|
||||
ModelConfigBase: {
|
||||
/** Path */
|
||||
path: string;
|
||||
/** Description */
|
||||
description?: string;
|
||||
/** Model Format */
|
||||
model_format?: string;
|
||||
error?: components["schemas"]["ModelError"];
|
||||
};
|
||||
/**
|
||||
* ModelError
|
||||
* @description An enumeration.
|
||||
@ -3036,7 +3115,7 @@ export type components = {
|
||||
/** ModelsList */
|
||||
ModelsList: {
|
||||
/** Models */
|
||||
models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[];
|
||||
models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[];
|
||||
};
|
||||
/**
|
||||
* MultiplyInvocation
|
||||
@ -3425,47 +3504,6 @@ export type components = {
|
||||
*/
|
||||
scribble?: boolean;
|
||||
};
|
||||
/**
|
||||
* PipelineModelField
|
||||
* @description Pipeline model field
|
||||
*/
|
||||
PipelineModelField: {
|
||||
/**
|
||||
* Model Name
|
||||
* @description Name of the model
|
||||
*/
|
||||
model_name: string;
|
||||
/** @description Base model */
|
||||
base_model: components["schemas"]["BaseModelType"];
|
||||
};
|
||||
/**
|
||||
* PipelineModelLoaderInvocation
|
||||
* @description Loads a pipeline model, outputting its submodels.
|
||||
*/
|
||||
PipelineModelLoaderInvocation: {
|
||||
/**
|
||||
* Id
|
||||
* @description The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
/**
|
||||
* Is Intermediate
|
||||
* @description Whether or not this node is an intermediate node.
|
||||
* @default false
|
||||
*/
|
||||
is_intermediate?: boolean;
|
||||
/**
|
||||
* Type
|
||||
* @default pipeline_model_loader
|
||||
* @enum {string}
|
||||
*/
|
||||
type?: "pipeline_model_loader";
|
||||
/**
|
||||
* Model
|
||||
* @description The model to load
|
||||
*/
|
||||
model: components["schemas"]["PipelineModelField"];
|
||||
};
|
||||
/**
|
||||
* PromptCollectionOutput
|
||||
* @description Base class for invocations that output a collection of prompts
|
||||
@ -4266,6 +4304,19 @@ export type components = {
|
||||
*/
|
||||
level?: 2 | 4;
|
||||
};
|
||||
/**
|
||||
* VAEModelField
|
||||
* @description Vae model field
|
||||
*/
|
||||
VAEModelField: {
|
||||
/**
|
||||
* Model Name
|
||||
* @description Name of the model
|
||||
*/
|
||||
model_name: string;
|
||||
/** @description Base model */
|
||||
base_model: components["schemas"]["BaseModelType"];
|
||||
};
|
||||
/** VaeField */
|
||||
VaeField: {
|
||||
/**
|
||||
@ -4274,6 +4325,51 @@ export type components = {
|
||||
*/
|
||||
vae: components["schemas"]["ModelInfo"];
|
||||
};
|
||||
/**
|
||||
* VaeLoaderInvocation
|
||||
* @description Loads a VAE model, outputting a VaeLoaderOutput
|
||||
*/
|
||||
VaeLoaderInvocation: {
|
||||
/**
|
||||
* Id
|
||||
* @description The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
/**
|
||||
* Is Intermediate
|
||||
* @description Whether or not this node is an intermediate node.
|
||||
* @default false
|
||||
*/
|
||||
is_intermediate?: boolean;
|
||||
/**
|
||||
* Type
|
||||
* @default vae_loader
|
||||
* @enum {string}
|
||||
*/
|
||||
type?: "vae_loader";
|
||||
/**
|
||||
* Vae Model
|
||||
* @description The VAE to load
|
||||
*/
|
||||
vae_model: components["schemas"]["VAEModelField"];
|
||||
};
|
||||
/**
|
||||
* VaeLoaderOutput
|
||||
* @description Model loader output
|
||||
*/
|
||||
VaeLoaderOutput: {
|
||||
/**
|
||||
* Type
|
||||
* @default vae_loader_output
|
||||
* @enum {string}
|
||||
*/
|
||||
type?: "vae_loader_output";
|
||||
/**
|
||||
* Vae
|
||||
* @description Vae model
|
||||
*/
|
||||
vae?: components["schemas"]["VaeField"];
|
||||
};
|
||||
/** VaeModelConfig */
|
||||
VaeModelConfig: {
|
||||
/** Name */
|
||||
@ -4474,7 +4570,7 @@ export type operations = {
|
||||
};
|
||||
requestBody: {
|
||||
content: {
|
||||
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
|
||||
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
|
||||
};
|
||||
};
|
||||
responses: {
|
||||
@ -4511,7 +4607,7 @@ export type operations = {
|
||||
};
|
||||
requestBody: {
|
||||
content: {
|
||||
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
|
||||
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
|
||||
};
|
||||
};
|
||||
responses: {
|
||||
@ -4731,13 +4827,13 @@ export type operations = {
|
||||
};
|
||||
};
|
||||
/**
|
||||
* Import Model
|
||||
* Update Model
|
||||
* @description Add Model
|
||||
*/
|
||||
import_model: {
|
||||
update_model: {
|
||||
requestBody: {
|
||||
content: {
|
||||
"application/json": components["schemas"]["ImportModelRequest"];
|
||||
"application/json": components["schemas"]["CreateModelRequest"];
|
||||
};
|
||||
};
|
||||
responses: {
|
||||
@ -4755,6 +4851,36 @@ export type operations = {
|
||||
};
|
||||
};
|
||||
};
|
||||
/**
|
||||
* Import Model
|
||||
* @description Add a model using its local path, repo_id, or remote URL
|
||||
*/
|
||||
import_model: {
|
||||
parameters: {
|
||||
query: {
|
||||
/** @description A model path, repo_id or URL to import */
|
||||
name: string;
|
||||
/** @description Prediction type for SDv2 checkpoint files */
|
||||
prediction_type?: "v_prediction" | "epsilon" | "sample";
|
||||
};
|
||||
};
|
||||
responses: {
|
||||
/** @description The model imported successfully */
|
||||
201: {
|
||||
content: {
|
||||
"application/json": components["schemas"]["ImportModelResponse"];
|
||||
};
|
||||
};
|
||||
/** @description The model could not be found */
|
||||
404: never;
|
||||
/** @description Validation Error */
|
||||
422: {
|
||||
content: {
|
||||
"application/json": components["schemas"]["HTTPValidationError"];
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
/**
|
||||
* Delete Model
|
||||
* @description Delete Model
|
||||
|
@ -33,7 +33,8 @@ export type OffsetPaginatedResults_ImageDTO_ =
|
||||
// Models
|
||||
export type ModelType = S<'ModelType'>;
|
||||
export type BaseModelType = S<'BaseModelType'>;
|
||||
export type PipelineModelField = S<'PipelineModelField'>;
|
||||
export type MainModelField = S<'MainModelField'>;
|
||||
export type VAEModelField = S<'VAEModelField'>;
|
||||
export type ModelsList = S<'ModelsList'>;
|
||||
|
||||
// Graphs
|
||||
@ -57,8 +58,8 @@ export type TextToLatentsInvocation = N<'TextToLatentsInvocation'>;
|
||||
export type LatentsToLatentsInvocation = N<'LatentsToLatentsInvocation'>;
|
||||
export type ImageToLatentsInvocation = N<'ImageToLatentsInvocation'>;
|
||||
export type LatentsToImageInvocation = N<'LatentsToImageInvocation'>;
|
||||
export type PipelineModelLoaderInvocation = N<'PipelineModelLoaderInvocation'>;
|
||||
export type ImageCollectionInvocation = N<'ImageCollectionInvocation'>;
|
||||
export type MainModelLoaderInvocation = N<'MainModelLoaderInvocation'>;
|
||||
|
||||
// ControlNet Nodes
|
||||
export type ControlNetInvocation = N<'ControlNetInvocation'>;
|
||||
|
File diff suppressed because one or more lines are too long
@ -1328,6 +1328,14 @@
|
||||
react-remove-scroll "^2.5.5"
|
||||
react-textarea-autosize "8.3.4"
|
||||
|
||||
"@mantine/form@^6.0.15":
|
||||
version "6.0.15"
|
||||
resolved "https://registry.yarnpkg.com/@mantine/form/-/form-6.0.15.tgz#e78d953669888e01d3778ee8f62d469a12668c42"
|
||||
integrity sha512-Tz4AuZZ/ddGvEh5zJbDyi9PlGqTilJBdCjRGIgs3zn3hQsfg+ku7/NUR5zNB64dcWPJvGKc074y4iopNIl3FWQ==
|
||||
dependencies:
|
||||
fast-deep-equal "^3.1.3"
|
||||
klona "^2.0.5"
|
||||
|
||||
"@mantine/hooks@^6.0.14":
|
||||
version "6.0.14"
|
||||
resolved "https://registry.yarnpkg.com/@mantine/hooks/-/hooks-6.0.14.tgz#5f52a75cdd36b14c13a5ffeeedc510d73db76dc0"
|
||||
@ -4454,6 +4462,11 @@ klaw-sync@^6.0.0:
|
||||
dependencies:
|
||||
graceful-fs "^4.1.11"
|
||||
|
||||
klona@^2.0.5:
|
||||
version "2.0.6"
|
||||
resolved "https://registry.yarnpkg.com/klona/-/klona-2.0.6.tgz#85bffbf819c03b2f53270412420a4555ef882e22"
|
||||
integrity sha512-dhG34DXATL5hSxJbIexCft8FChFXtmskoZYnoPWjXQuebWYCNkVeV3KkGegCK9CP1oswI/vQibS2GY7Em/sJJA==
|
||||
|
||||
kolorist@^1.7.0:
|
||||
version "1.8.0"
|
||||
resolved "https://registry.yarnpkg.com/kolorist/-/kolorist-1.8.0.tgz#edddbbbc7894bc13302cdf740af6374d4a04743c"
|
||||
|
Loading…
Reference in New Issue
Block a user